# Lab 4: LangGraph + Text2SQL

![Intro1](../images/text2sql/langgraph.png)

### Basic Configuration

In [None]:
!pip install opensearch-py
!pip install langgraph

In [None]:
import boto3
import json
import copy
from botocore.config import Config
from sqlalchemy import create_engine
from src.opensearch import OpenSearchVectorRetriever, OpenSearchClient

# Shared AWS session; its configured region is reused by every client below.
boto_session = boto3.Session()
region_name = boto_session.region_name

# Bedrock model used for every LLM call in this lab.
# Alternative model (uncomment to switch):
# llm_model = "anthropic.claude-3-5-haiku-20241022-v1:0"
llm_model = "anthropic.claude-3-5-sonnet-20241022-v2:0"

# Local SQLite copy of the Chinook database, used for schema inspection,
# sample rows, EXPLAIN and query execution.
engine = create_engine("sqlite:///Chinook.db")
DIALECT = "sqlite"

def converse_with_bedrock(sys_prompt, usr_prompt):
    """Send one Bedrock Converse request and return the reply text.

    Uses the module-level ``boto3_client`` and ``llm_model`` with fixed,
    low-variance sampling settings (temperature 0, topP 0.1, top_k 1).

    Args:
        sys_prompt: System prompt blocks, e.g. ``[{"text": ...}]`` as built
            by ``create_prompt``.
        usr_prompt: Conversation messages, e.g. ``[{"role": "user", ...}]``.

    Returns:
        The ``text`` of the first content block of the model's message.
    """
    reply = boto3_client.converse(
        modelId=llm_model,
        messages=usr_prompt,
        system=sys_prompt,
        inferenceConfig={"temperature": 0.0, "topP": 0.1},
        additionalModelRequestFields={"top_k": 1},
    )
    first_block = reply['output']['message']['content'][0]
    return first_block['text']

def init_boto3_client(region: str):
    """Create a ``bedrock-runtime`` client for *region* with standard-mode retries.

    Args:
        region: AWS region name for the client.

    Returns:
        A boto3 Bedrock Runtime client configured for up to 10 attempts.
    """
    retry_policy = {"max_attempts": 10, "mode": "standard"}
    client_config = Config(region_name=region, retries=retry_policy)
    return boto3.client("bedrock-runtime", region_name=region, config=client_config)

def init_search_resources():
    """Build the OpenSearch clients and vector retrievers used by this lab.

    Two indexes are used: ``example_queries`` (question -> SQL examples) and
    ``schema_descriptions`` (per-table summaries). Each gets a client and a
    top-10 vector retriever.

    Returns:
        ``(sql_search_client, table_search_client, sql_retriever, table_retriever)``
    """
    sql_search_client = OpenSearchClient(
        region_name=region_name,
        index_name='example_queries',
        mapping_name='mappings-sql',
        vector="input_v",
        text="input",
        output=["input", "query"],
    )
    table_search_client = OpenSearchClient(
        region_name=region_name,
        index_name='schema_descriptions',
        mapping_name='mappings-detailed-schema',
        vector="table_summary_v",
        text="table_summary",
        output=["table_name", "table_summary"],
    )

    # Same retriever settings for both indexes, built in the same order.
    sql_retriever, table_retriever = (
        OpenSearchVectorRetriever(client, region_name=region_name, k=10)
        for client in (sql_search_client, table_search_client)
    )
    return sql_search_client, table_search_client, sql_retriever, table_retriever

def get_column_description(table_name):
    """Look up stored column descriptions for *table_name* in the schema index.

    Args:
        table_name: Table to match against the ``table_name`` field.

    Returns:
        ``{col_name: col_desc}`` taken from the top hit's ``columns`` list, or
        an empty dict when nothing matches or the hit has no columns.
    """
    search_body = {"query": {"match": {"table_name": table_name}}}
    response = table_search_client.conn.search(index=table_search_client.index_name, body=search_body)

    if response['hits']['total']['value'] <= 0:
        return {}

    top_source = response['hits']['hits'][0]['_source']
    # `or []` also covers a present-but-empty/None `columns` value.
    column_entries = top_source.get('columns') or []
    return {entry['col_name']: entry['col_desc'] for entry in column_entries}

def search_by_keywords(keyword):
    """Find table columns whose descriptions match *keyword*.

    Runs a nested ``match`` query on ``columns.col_desc`` in the schema index
    (one best inner hit per table) and collects at most five matches.

    Args:
        keyword: Free-text term matched against column descriptions.

    Returns:
        A JSON array string of ``{"table_name", "column_name",
        "column_description"}`` objects, or ``"<keyword> not found"`` when the
        response does not have the expected shape.
    """
    import itertools

    max_results = 5
    query = {
        "size": 10,
        "query": {
            "nested": {
                "path": "columns",
                "query": {
                    "match": {
                        "columns.col_desc": f"{keyword}"
                    }
                },
                "inner_hits": {
                    "size": 1,
                    "_source": ["columns.col_name", "columns.col_desc"]
                }
            }
        },
        "_source": ["table_name"]
    }
    response = table_search_client.conn.search(
        index=table_search_client.index_name,
        body=query
    )

    def _iter_matches():
        # Yield one record per inner (column) hit, in search-rank order.
        if 'hits' not in response or 'hits' not in response['hits']:
            return
        for hit in response['hits']['hits']:
            table_name = hit['_source']['table_name']
            for inner_hit in hit['inner_hits']['columns']['hits']['hits']:
                yield {
                    "table_name": table_name,
                    "column_name": inner_hit['_source']['col_name'],
                    "column_description": inner_hit['_source']['col_desc'],
                }

    try:
        # islice stops consuming hits as soon as max_results records are collected.
        results = list(itertools.islice(_iter_matches(), max_results))
        return json.dumps(results, ensure_ascii=False)
    except (KeyError, TypeError):
        # Unexpected response shape: report "not found" instead of hiding every error.
        return f"{keyword} not found"

def create_prompt(sys_template, user_template, **kwargs):
    """Render both prompt templates and wrap them in Bedrock Converse shapes.

    Args:
        sys_template: ``str.format`` template for the system prompt.
        user_template: ``str.format`` template for the single user turn.
        **kwargs: Values substituted into both templates.

    Returns:
        ``(system_blocks, messages)``: ``[{"text": ...}]`` and
        ``[{"role": "user", "content": [{"text": ...}]}]``.
    """
    system_text = sys_template.format(**kwargs)
    user_text = user_template.format(**kwargs)
    user_message = {"role": "user", "content": [{"text": user_text}]}
    return [{"text": system_text}], [user_message]

# Instantiate the shared clients. This must run before any function above is
# called: converse_with_bedrock reads boto3_client, and the search helpers read
# table_search_client / sql_retriever / table_retriever as globals.
boto3_client = init_boto3_client(region_name)
sql_search_client, table_search_client, sql_retriever, table_retriever = init_search_resources()

### Search Test

In [None]:
question = "Genres people like"

sql_search_result = sql_retriever.vector_search(question)
table_search_result = table_retriever.vector_search(question)

# Pretty-print the top hit from each index (skipping an index with no hits).
for label, hits in (("Sample query search result: ", sql_search_result),
                    ("Table search result: ", table_search_result)):
    if not hits:
        continue
    page_content = json.loads(hits[0].page_content)
    print(label, json.dumps(page_content, indent=4, ensure_ascii=False))

### GraphState Initialization

In [4]:
from typing import TypedDict

class GraphState(TypedDict):
    """Shared state passed between the LangGraph nodes of this lab.

    Nodes return partial ``GraphState`` objects that contain only the keys
    they update.
    """
    question: str  # user's natural-language question
    intent: str  # set by analyze_intent; the prompt asks for 'database' or 'general'
    sample_queries: list  # example {"input", "query"} records chosen by get_sample_queries
    readiness: str  # set by check_readiness; the prompt asks for 'Ready' or 'Not Ready'
    tables_summaries: list  # not written by any node shown here
    tables: list  # full table records selected by get_relevant_tables (it writes this key)
    table_names: list  # names of the tables selected by get_relevant_tables
    table_details: list  # per-table column/DDL/sample-row dicts built by describe_schema
    query_state: dict  # status/query/result/error dict shaped like initial_query_state
    next_action: str  # failure_type chosen by handle_failure
    answer: str  # not written by any node shown here
    dialect: str  # not written by any node shown here; nodes use the DIALECT constant
    

### SubGraph1 - Schema Linking definition

In [5]:
### Schema Linking - SubGraph1
from sqlalchemy import inspect, text

# Output-format instruction for prompts that expect a comma-separated list.
csv_list_response_format = "Your response should be a list of comma separated values, eg: `foo, bar, baz` or `foo,bar,baz`"
# Output-format instruction for handle_failure / handle_failure_dev. They parse
# the reply as JSON and read the `failure_type` and `hint` keys, so the schema
# below declares exactly those keys.
json_response_format = """The output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"properties": {"failure_type": {"title": "Failure Type", "description": "cause of the failure", "type": "string", "enum": ["syntax_check", "schema_check", "stop", "retry"]}, "hint": {"title": "Hint", "description": "clue for resolving the failure", "type": "string"}}, "required": ["failure_type", "hint"]}\n```"""

def analyze_intent(state: GraphState) -> GraphState:
    """Classify the user's question as needing a database query or not.

    Returns:
        A partial ``GraphState`` with ``intent``: the model's answer with
        surrounding whitespace, quotes and periods removed, lowercased. The
        prompt asks for 'database' or 'general'.
    """
    question = state["question"]
    sys_prompt_template = "You are an assistant who understands the intent of user questions. Your task is to classify each user question into one category."
    # Plain template (not an f-string): create_prompt fills {question} via
    # str.format. Pre-interpolating it would make a question containing
    # braces break the second formatting pass.
    usr_prompt_template = "If a database query is needed to answer the user's question, respond with 'database'. Otherwise, respond with 'general'. Skip any preamble. \n\n #Question: {question}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question)
    raw_intent = converse_with_bedrock(sys_prompt, usr_prompt)

    # The value is used as a routing key, so normalize cosmetic variations
    # such as "'database'", " Database\n" or "general.".
    intent = raw_intent.strip().strip("'\"`.").strip().lower()
    return GraphState(intent=intent)

def get_sample_queries(state: GraphState) -> GraphState:
    """Retrieve example queries and keep the ones the LLM considers useful.

    Vector-searches the example index for the question, shows the model the
    example inputs, and parses its comma-separated list of indices.

    Returns:
        A partial ``GraphState`` with ``sample_queries``: parsed
        ``{"input", "query"}`` records in the model's importance order, or an
        empty list when none are selected.
    """
    question = state["question"]
    samples = sql_retriever.vector_search(question)
    page_contents = [doc.page_content for doc in samples if doc is not None]
    sample_inputs = [json.loads(content)['input'] for content in page_contents]

    sys_prompt_template = "You are a skilled database engineer who writes SQL queries for user questions. Your task is to select sample queries that are useful for creating SQL queries that match the question. The sample queries you select can be used for query reuse, schema reference, etc."
    # `(\"\")` is escaped: an unescaped `""` inside the literal was implicit
    # string concatenation and rendered as `()` in the prompt.
    usr_prompt_template = "Select sample queries that are useful for writing SQL queries that match the question, and respond with them sorted by importance. Respond only with the index numbers of the sample queries (starting from 0). If there are no relevant samples, respond with an empty list (\"\"). \n\n #Question: {question}\n\n #Sample queries:\n {sample_inputs}\n\n #Format: {csv_list_response_format}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, sample_inputs=sample_inputs, csv_list_response_format=csv_list_response_format)
    sample_ids = converse_with_bedrock(sys_prompt, usr_prompt)

    # Keep only in-range integer indices, in the model's order and without
    # repeats. An empty answer ('""'), stray quotes/backticks or an
    # out-of-range index no longer discard the other valid selections.
    selected_ids = []
    for token in sample_ids.split(','):
        token = token.strip().strip('`"\'[]()').strip()
        if token.isdecimal():
            idx = int(token)
            if idx < len(page_contents) and idx not in selected_ids:
                selected_ids.append(idx)

    sample_queries = [json.loads(page_contents[idx]) for idx in selected_ids]
    return GraphState(sample_queries=sample_queries)
    
def check_readiness(state: GraphState) -> GraphState:
    """Ask the LLM whether the gathered context is enough to write the SQL.

    Returns:
        A partial ``GraphState`` with ``readiness``. 'ready'/'not ready' are
        mapped case-insensitively to 'Ready'/'Not Ready'; other answers are
        returned with whitespace, quotes and periods trimmed.
    """
    print(state)
    question = state["question"]
    sample_queries = state.get("sample_queries", [])
    table_details = state.get("table_details", "")

    sys_prompt_template = "You are a skilled database engineer who writes SQL queries for user questions. Your task is to determine whether it's possible to write an SQL query for the user's question based on the given database information."
    usr_prompt_template = "Determine if sufficient information has been provided to generate an SQL query for the question. Respond with 'Ready' if there's enough information, or 'Not Ready' if the information is insufficient. \n\n #Question: {question}\n\n #Sample queries:\n {sample_queries}\n\n #Available tables:\n {table_details} \n\n Skip the preamble or explaination. Only provide 'Ready' or 'Not Ready'"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, sample_queries=sample_queries, table_details=table_details)
    raw_readiness = converse_with_bedrock(sys_prompt, usr_prompt)

    # readiness is a routing key; normalize answers like " ready.\n" or "'Not Ready'".
    trimmed = raw_readiness.strip().strip("'\"`.").strip()
    canonical = {"ready": "Ready", "not ready": "Not Ready"}
    readiness = canonical.get(trimmed.lower(), trimmed)

    return GraphState(readiness=readiness)

def get_relevant_tables(state: GraphState) -> GraphState:
    """Retrieve table summaries and keep the tables the LLM says are needed.

    Returns:
        A partial ``GraphState`` with ``tables`` (parsed table records in the
        model's importance order) and ``table_names`` (their ``table_name``
        values). Both are empty when nothing is selected.
    """
    question = state["question"]
    docs = table_retriever.vector_search(question)
    page_contents = [doc.page_content for doc in docs if doc is not None]
    table_inputs = [json.loads(content)['table_summary'] for content in page_contents]

    sys_prompt_template = "You are a skilled database engineer who writes SQL queries to match user requests. Your task is to select the tables needed to write the SQL query."
    # `(\"\")` is escaped: an unescaped `""` was implicit concatenation and rendered as `()`.
    usr_prompt_template = "Select the tables needed to generate an SQL query that matches the user's request, sort them by importance, and respond with their index numbers (starting from 0). If there are no tables relevant to the user's request, respond with an empty list (\"\").\n\n #Question: {question}\n\n #Table information:\n {table_inputs}\n\n #Format: {csv_list_response_format}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, table_inputs=table_inputs, csv_list_response_format=csv_list_response_format)
    table_ids = converse_with_bedrock(sys_prompt, usr_prompt)

    # Keep only in-range integer indices, in order, without repeats.
    selected_ids = []
    for token in table_ids.split(','):
        token = token.strip().strip('`"\'[]()').strip()
        if token.isdecimal():
            idx = int(token)
            if idx < len(page_contents) and idx not in selected_ids:
                selected_ids.append(idx)

    # Drop records without a table_name so `tables` and `table_names` stay aligned.
    tables = [json.loads(page_contents[idx]) for idx in selected_ids]
    tables = [table for table in tables if 'table_name' in table]
    table_names = [table['table_name'] for table in tables]
    return GraphState(tables=tables, table_names=table_names)

def describe_schema(state: GraphState) -> GraphState:
    """Build prompt-ready schema details for the tables selected earlier.

    For each name in ``state["table_names"]`` this collects the column list
    (via the SQLAlchemy inspector), a synthetic ``CREATE TABLE`` statement,
    up to five sample rows, and (when ``table_search_client`` exists) the
    column descriptions from OpenSearch. Those descriptions replace the raw
    column types when present. Tables that do not exist in the database are
    skipped with a message.

    Returns:
        A partial ``GraphState`` with ``table_details``: a list of dicts with
        keys ``table``, ``cols``, ``create_table_sql`` and ``sample_data``.
    """
    table_names = state["table_names"]
    table_details = []
    inspector = inspect(engine)
    # Table names come from the vector index / LLM selection. Quote them using
    # the engine's rules instead of pasting them raw into SQL, so reserved
    # words and names containing spaces or punctuation still work.
    quote_identifier = engine.dialect.identifier_preparer.quote

    # Reuse one connection for every sample-row query.
    with engine.connect() as connection:
        for table_name in table_names:
            if not inspector.has_table(table_name):
                # The LLM may pick a name that is not in the database.
                print(f"Table {table_name} not found in the database; skipping")
                continue

            columns = inspector.get_columns(table_name)

            column_lines = ",\n".join([f"    {col['name']} {col['type']}" for col in columns])
            create_table_sql = f"CREATE TABLE {table_name} (\n{column_lines}\n);"

            result = connection.execute(text(f"SELECT * FROM {quote_identifier(table_name)} LIMIT 5"))
            sample_data = [dict(zip(result.keys(), row)) for row in result]

            table_desc = get_column_description(table_name) if 'table_search_client' in globals() else {}

            table_detail = {
                "table": table_name,
                "cols": table_desc if table_desc else {col['name']: str(col['type']) for col in columns},
                "create_table_sql": create_table_sql,
                "sample_data": str(sample_data) if sample_data else "No sample data available"
            }

            if not table_detail["cols"]:
                print(f"No columns found for table {table_name}")
            table_details.append(table_detail)

    return GraphState(table_details=table_details)

def next_step_by_intent(state: GraphState) -> str:
    """Routing function: return the ``intent`` value as the next-edge key."""
    return state["intent"]

def next_step_by_readiness(state: GraphState) -> str:
    """Routing function: return the ``readiness`` value as the next-edge key."""
    return state["readiness"]


### SubGraph1 - Code modules (Dev)

In [6]:
### Schema Linking - SubGraph1 Modules (Dev)

def analyze_intent_dev(question):
    """Dev version of analyze_intent: classify *question* and return the intent.

    Returns:
        The model's answer with whitespace, quotes and periods removed,
        lowercased. The prompt asks for 'database' or 'general'.
    """
    sys_prompt_template = "You are an assistant who understands the intent of user questions. Your task is to classify each user question into one category."
    # Plain template (not an f-string): create_prompt formats it once, so
    # braces inside the question cannot break formatting.
    usr_prompt_template = "If a database query is needed to answer the user's question, respond with 'database'. Otherwise, respond with 'general'. Skip any preamble. \n\n #Question: {question}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question)
    raw_intent = converse_with_bedrock(sys_prompt, usr_prompt)

    return raw_intent.strip().strip("'\"`.").strip().lower()

def get_general_answer_dev(question):
    """Answer a non-database question directly with the LLM.

    Args:
        question: The user's question.

    Returns:
        The model's reply text.
    """
    prompts = create_prompt(
        "You are a capable assistant who answers general questions from users. If you don't know the answer to a question, admit that you don't know.",
        "#Question: {question}",
        question=question,
    )
    return converse_with_bedrock(*prompts)

def get_sample_queries_dev(question):
    """Dev version of get_sample_queries.

    Returns:
        A list of parsed ``{"input", "query"}`` example records the LLM chose,
        in its importance order. Empty when none are selected.
    """
    samples = sql_retriever.vector_search(question)
    page_contents = [doc.page_content for doc in samples if doc is not None]
    sample_inputs = [json.loads(content)['input'] for content in page_contents]

    sys_prompt_template = "You are a skilled database engineer who writes SQL queries for user questions. Your task is to select sample queries that are useful for creating SQL queries that match the question. The sample queries you select can be used for query reuse, schema reference, etc."
    # `(\"\")` is escaped: an unescaped `""` was implicit concatenation and rendered as `()`.
    usr_prompt_template = "Select sample queries that are useful for writing SQL queries that match the question, and respond with them sorted by importance. Respond only with the index numbers of the sample queries (starting from 0). If there are no relevant samples, respond with an empty list (\"\"). \n\n #Question: {question}\n\n #Sample queries:\n {sample_inputs}\n\n #Format: {csv_list_response_format}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, sample_inputs=sample_inputs, csv_list_response_format=csv_list_response_format)
    sample_ids = converse_with_bedrock(sys_prompt, usr_prompt)

    # Keep only in-range integer indices, in order, without repeats.
    selected_ids = []
    for token in sample_ids.split(','):
        token = token.strip().strip('`"\'[]()').strip()
        if token.isdecimal():
            idx = int(token)
            if idx < len(page_contents) and idx not in selected_ids:
                selected_ids.append(idx)

    return [json.loads(page_contents[idx]) for idx in selected_ids]
    
def check_readiness_dev(question, sample_queries, table_details):
    """Dev version of check_readiness.

    Returns:
        'Ready' or 'Not Ready' (matched case-insensitively). Any other answer
        is returned with whitespace, quotes and periods trimmed.
    """
    sys_prompt_template = "You are a skilled database engineer who writes SQL queries for user questions. Your task is to determine whether it's possible to write an SQL query for the user's question based on the given database information."
    usr_prompt_template = "Determine if sufficient information has been provided to generate an SQL query for the question. Skip any preamble and respond with 'Ready' if there's enough information, or 'Not Ready' if the information is insufficient.\n\n #Question: {question}\n\n #Sample queries:\n {sample_queries}\n\n #Available tables:\n {table_details}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, sample_queries=sample_queries, table_details=table_details)
    raw_readiness = converse_with_bedrock(sys_prompt, usr_prompt)

    trimmed = raw_readiness.strip().strip("'\"`.").strip()
    canonical = {"ready": "Ready", "not ready": "Not Ready"}
    return canonical.get(trimmed.lower(), trimmed)

def get_relevant_tables_dev(question):
    """Dev version of get_relevant_tables.

    Returns:
        ``(tables, table_names)``: parsed table records chosen by the LLM (in
        its importance order) and their ``table_name`` values. Returns
        ``([], [])`` when nothing is selected, so callers can always unpack
        two values.
    """
    docs = table_retriever.vector_search(question)
    page_contents = [doc.page_content for doc in docs if doc is not None]
    table_inputs = [json.loads(content)['table_summary'] for content in page_contents]
    sys_prompt_template = "You are a skilled database engineer who writes SQL queries to match user requests. Your task is to select the tables needed to write the SQL query."
    # `(\"\")` is escaped: an unescaped `""` was implicit concatenation and rendered as `()`.
    usr_prompt_template = "Select the tables needed to generate an SQL query that matches the user's request, sort them by importance, and respond with their index numbers (starting from 0). If there are no tables relevant to the user's request, respond with an empty list (\"\").\n\n #Question: {question}\n\n #Table information:\n {table_inputs}\n\n #Format: {csv_list_response_format}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, table_inputs=table_inputs, csv_list_response_format=csv_list_response_format)
    table_ids = converse_with_bedrock(sys_prompt, usr_prompt)

    # Keep only in-range integer indices, in order, without repeats.
    selected_ids = []
    for token in table_ids.split(','):
        token = token.strip().strip('`"\'[]()').strip()
        if token.isdecimal():
            idx = int(token)
            if idx < len(page_contents) and idx not in selected_ids:
                selected_ids.append(idx)

    tables = [json.loads(page_contents[idx]) for idx in selected_ids]
    tables = [table for table in tables if 'table_name' in table]
    table_names = [table['table_name'] for table in tables]
    return tables, table_names

def describe_schema_dev(table_names):
    """Dev version of describe_schema.

    Args:
        table_names: Names of the tables to describe.

    Returns:
        A list of dicts with keys ``table``, ``cols``, ``create_table_sql`` and
        ``sample_data``. Tables missing from the database are skipped with a
        message.
    """
    table_details = []
    inspector = inspect(engine)
    # Quote names using the engine's rules rather than pasting them raw into SQL.
    quote_identifier = engine.dialect.identifier_preparer.quote

    # Reuse one connection for every sample-row query.
    with engine.connect() as connection:
        for table_name in table_names:
            if not inspector.has_table(table_name):
                print(f"Table {table_name} not found in the database; skipping")
                continue

            columns = inspector.get_columns(table_name)

            column_lines = ",\n".join([f"    {col['name']} {col['type']}" for col in columns])
            create_table_sql = f"CREATE TABLE {table_name} (\n{column_lines}\n);"

            result = connection.execute(text(f"SELECT * FROM {quote_identifier(table_name)} LIMIT 5"))
            sample_data = [dict(zip(result.keys(), row)) for row in result]

            table_desc = get_column_description(table_name) if 'table_search_client' in globals() else {}

            table_detail = {
                "table": table_name,
                "cols": table_desc if table_desc else {col['name']: str(col['type']) for col in columns},
                "create_table_sql": create_table_sql,
                "sample_data": str(sample_data) if sample_data else "No sample data available"
            }

            if not table_detail["cols"]:
                print(f"No columns found for table {table_name}")
            table_details.append(table_detail)

    return table_details

### SubGraph1 - Code modules (Test)

In [7]:
### Schema Linking - SubGraph1 Modules (Test)

question1 = "What are the top 10 countries by sales in 2022?"
question2 = "What's the weather like today?"
question3 = "Who are the top 10 customers based on purchase quantity?"

#===================================================================================================

# 1 - analyze_intent: one database-style question, one general question
for q in (question1, question2):
    print(analyze_intent_dev(q))

# 2 - get_sample_queries
for q in (question1, question3):
    print(get_sample_queries_dev(q))

# 3 - check_readiness with hand-written sample queries and no table details
sample_queries = [
    '{"input": "What are the details of customers residing in Canada?", "query": "SELECT * FROM Customer WHERE Country = \'Canada\'"}',
    '{"input": "How many tracks are on the album with ID 5?", "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5"}',
    '{"input": "What are the top 5 customers by total purchase amount?", "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5"}'
]
print(check_readiness_dev(question3, sample_queries, table_details=""))

# 4 - get_relevant_tables
for q in (question1, question3):
    print(get_relevant_tables_dev(q))

# 5 - describe_schema for a single table
table_names = ['Invoice']
print(describe_schema_dev(table_names))

# 6 - check_readiness with table details and no sample queries
table_names = ['Invoice', 'Customer']
table_details = describe_schema_dev(table_names)
print(check_readiness_dev(question1, sample_queries="", table_details=table_details))

#===================================================================================================

### SubGraph2 - Text2SQL Process

In [8]:
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

Session = sessionmaker(bind=engine)

### Query Generation & Execution - SubGraph2

# Template for a fresh query attempt; nodes deep-copy it before filling it in.
initial_query_state = dict(
    status="success",
    query="",
    result="",
    error=dict(code="", message="", failed_step="", hint=""),
)

def generate_query(state: GraphState) -> GraphState:
    """Ask the LLM for a SQL query, using examples, schema details and any retry hint.

    A fresh ``query_state`` (copied from ``initial_query_state``) is returned
    for each attempt, so status and error from a previous attempt are reset.

    Returns:
        A partial ``GraphState`` whose ``query_state["query"]`` holds the
        generated SQL.
    """
    dialect = DIALECT
    new_query_state = copy.deepcopy(initial_query_state)
    question = state["question"]
    sample_queries = state["sample_queries"]
    table_details = state["table_details"]

    # The previous attempt's hint may be stored under query_state["error"]["hint"]
    # or at query_state["hint"] (where handle_failure writes it). Check both and
    # fall back to "None" when neither is set or both are empty.
    query_state = state.get("query_state", {}) or {}
    error_info = query_state.get("error", {}) or {}
    hint = error_info.get("hint") or query_state.get("hint") or "None"

    sys_prompt_template = "You are a skilled database engineer who writes {dialect} SQL queries in response to user questions. Your task is to create accurate SQL queries that match the user's question based on the given database information."
    usr_prompt_template = "Based on the following sample queries, schema information, and past failure history, create a query that matches the DB dialect. Skip the introduction and provide only the generated SQL query statement. \n\n #Question: {question}\n\n #Sample queries:\n {sample_queries}\n\n #Available tables:\n {table_details}\n\n #Additional information (past failure history, additional acquired information, etc.):\n {hint}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, dialect=dialect, sample_queries=sample_queries, table_details=table_details, hint=hint)
    generated_query = converse_with_bedrock(sys_prompt, usr_prompt)

    new_query_state["query"] = generated_query

    return GraphState(query_state=new_query_state)

def validate_query(state: GraphState) -> GraphState:
    """Check the generated query with the dialect's EXPLAIN, then let the LLM refine it.

    If EXPLAIN fails, ``query_state`` is marked as an E01 validation error and
    returned unchanged otherwise. On success, the query plan is given to the
    LLM, which returns the final query text.

    Returns:
        A partial ``GraphState`` with the updated ``query_state``.
    """
    dialect = DIALECT
    question = state["question"]
    query_state = copy.deepcopy(state["query_state"])
    query = query_state["query"]

    # Plan-only statements per dialect. PostgreSQL and Presto use plain EXPLAIN:
    # EXPLAIN ANALYZE actually runs the statement, so the generated SQL (and any
    # side effects) would execute here and again in execute_query.
    explain_statements = {
        'mysql': "EXPLAIN {query}",
        'mariadb': "EXPLAIN {query}",
        'sqlite': "EXPLAIN QUERY PLAN {query}",
        'oracle': "EXPLAIN PLAN FOR\n{query}\n\nSELECT * FROM TABLE(DBMS_XPLAN.DISPLAY);",
        'postgresql': "EXPLAIN {query}",
        'postgres': "EXPLAIN {query}",
        'presto': "EXPLAIN {query}",
        # NOTE(review): this form also executes the query; consider a plan-only option.
        'sqlserver': "SET STATISTICS PROFILE ON; {query}; SET STATISTICS PROFILE OFF;"
    }

    if dialect.lower() not in explain_statements:
        query_plan = " "
    else:
        try:
            # Drop a trailing ';' so the query can be embedded in the EXPLAIN template.
            plan_target = query.strip().rstrip(";")
            explain_query = explain_statements[dialect.lower()].format(query=plan_target)
            with Session() as session:
                result = session.execute(text(explain_query))
                query_plan = "\n".join([str(row) for row in result])
        except Exception as e:
            query_state["status"] = "error"
            query_state["error"]["code"] = "E01"
            query_state["error"]["message"] = f"An error occurred while executing the EXPLAIN query: {str(e)}"
            query_state["error"]["failed_step"] = "validation"
            query_state["query"] = query
            return GraphState(query_state=query_state)

    sys_prompt_template = "You are a database expert who reviews existing {dialect} SQL queries in response to user questions and optimizes them when necessary. Your task is to examine the query's coherence and potential for optimization based on the given SQL query and additional information, and provide a final query based on this analysis."
    usr_prompt_template = "Please add aliases to the query to match the user's question. It is not allowed to add tables or columns that were not used in the original SQL query. Skip the introduction and provide only the generated SQL query statement. \n\n #Question: {question}\n\n #Existing query:\n {query}\n\n #Query plan:\n {query_plan}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, dialect=dialect, query=query, query_plan=query_plan)
    validated_query = converse_with_bedrock(sys_prompt, usr_prompt)
    query_state["query"] = validated_query

    return GraphState(query_state=query_state)

def execute_query(state: GraphState) -> GraphState:
    """Run the validated query and record its rows or an E02 execution error.

    Returns:
        A partial ``GraphState`` with an updated copy of ``query_state``:
        ``result`` holds one ``str(row)`` per line on success; on failure
        ``status`` is "error" and ``error`` describes the exception.
    """
    updated = copy.deepcopy(state["query_state"])
    sql = updated["query"]
    try:
        with Session() as session:
            rows = session.execute(text(sql))
            updated["result"] = "\n".join(str(row) for row in rows)
    except Exception as exc:
        updated["status"] = "error"
        error = updated["error"]
        error["code"] = "E02"
        error["message"] = f"An error occurred while executing the validated query: {str(exc)}"
        error["failed_step"] = "execution"
    return GraphState(query_state=updated)
    
def handle_failure(state: GraphState) -> GraphState:
    """Ask the LLM to classify a query failure and suggest a fix.

    The model is asked for JSON with ``failure_type`` (syntax_check,
    schema_check, stop or retry) and ``hint``. If the reply cannot be parsed,
    the failure type falls back to "stop" and the hint to "None".

    Returns:
        A partial ``GraphState`` with ``next_action`` (the failure type) and a
        ``query_state`` copy whose hint is stored both at
        ``query_state["error"]["hint"]`` (read by generate_query) and at
        ``query_state["hint"]``.
    """
    query_state = copy.deepcopy(state["query_state"])
    query = query_state['query']
    message = query_state['error']['message']
    sys_prompt_template = "You are a skilled database engineer who handles SQL query failures. Your task is to identify the cause of failure for the given SQL query and determine the next steps for problem resolution."
    usr_prompt_template = "Based on the failure message of the given SQL query, provide one of the following causes (`failure_type`) along with a clue for resolution (`hint`).\nHere are examples of failure_type choices:\nInaccurate query syntax: `syntax_check`\nSchema mismatch: `schema_check`\nExternal DB factors (permissions, connection issues, etc.): `stop`\nTemporary DB malfunction (query re-execution needed): `retry`\n\n#Failed query: {query}\n\n#Failure message: {message}\n\n#Format: {json_response_format} Skip the preamble and only provide the valid JSON document."
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, query=query, message=message, json_response_format=json_response_format)
    result = converse_with_bedrock(sys_prompt, usr_prompt)

    # The model sometimes wraps the JSON in prose or ``` fences, so parse the
    # outermost {...} span and tolerate malformed output.
    start, end = result.find("{"), result.rfind("}")
    try:
        json_result = json.loads(result[start:end + 1]) if start != -1 and end > start else {}
    except json.JSONDecodeError:
        json_result = {}

    hint = json_result.get("hint", "None")
    failure_type = json_result.get("failure_type", "stop")
    query_state["error"]["hint"] = hint
    query_state["hint"] = hint
    return GraphState(next_action=failure_type, query_state=query_state)

def get_relevant_columns(state: GraphState) -> str:
    """Ask the LLM for schema-exploration keywords that could fix a failed query.

    NOTE(review): this returns the model's raw comma-separated keyword string,
    not a GraphState update. Confirm how callers use it before wiring it as a
    graph node.

    Returns:
        The model's reply: 3-5 comma-separated keywords or short phrases.
    """
    # Read-only access, so no deep copy is needed.
    query_state = state["query_state"]
    question = state["question"]
    query = query_state["query"]
    message = query_state['error']['message']
    sys_prompt_template = "You are an expert SQL query troubleshooter. Your task is to analyze failed queries and suggest relevant keywords for schema exploration to resolve the issue."
    usr_prompt_template = """Given a user question, a failed SQL query, and an error message, provide 3-5 relevant keywords or phrases for database schema exploration. These should help in finding the correct table and column names to fix the query.\n\n#User question: {question}\n\n#Failed query:\n{query}\n\n#Error message:\n{message}\n\nRespond only with a comma-separated list of keywords or short phrases, without any additional text or explanation.\n\n#Format: {csv_list_response_format}"""
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, query=query, message=message, csv_list_response_format=csv_list_response_format)
    keywords = converse_with_bedrock(sys_prompt, usr_prompt)
    return keywords

def next_step_by_query_state(state: GraphState) -> str:
    """Routing function: return ``query_state["status"]`` ("success" or "error")."""
    return state["query_state"]["status"]

def next_step_by_next_action(state: GraphState) -> str:
    """Routing function: return the ``next_action`` chosen by handle_failure."""
    return state["next_action"]


### SubGraph2 - Code modules (Dev)

In [9]:
def generate_query_dev(question, dialect, sample_queries, table_details, hint):
    """Dev version of generate_query with every input passed explicitly.

    Returns:
        A new query-state dict (from ``initial_query_state``) whose ``query``
        is the LLM-generated SQL.
    """
    sys_prompt, usr_prompt = create_prompt(
        "You are a skilled database engineer who writes {dialect} SQL queries in response to user questions. Your task is to create accurate SQL queries that match the user's question based on the given database information.",
        "Based on the following sample queries, schema information, and past failure history, create a query that matches the DB dialect. Skip the introduction and provide only the generated SQL query statement. \n\n #Question: {question}\n\n #Sample queries:\n {sample_queries}\n\n #Available tables:\n {table_details}\n\n #Additional information (past failure history, additional acquired information, etc.):\n {hint}",
        question=question,
        dialect=dialect,
        sample_queries=sample_queries,
        table_details=table_details,
        hint=hint,
    )
    fresh_state = copy.deepcopy(initial_query_state)
    fresh_state["query"] = converse_with_bedrock(sys_prompt, usr_prompt)
    return fresh_state

def validate_query_dev(question, dialect, query_state):
    """Dev version of validate_query.

    Runs the dialect's plan-only EXPLAIN through the module's ``Session`` (the
    same path validate_query uses; the previous ``db.run`` referenced a name
    not defined in this notebook). The LLM then refines the query.

    Returns:
        An updated copy of *query_state*. On EXPLAIN failure, ``status`` is
        "error" with an E01 validation error.
    """
    query_state = copy.deepcopy(query_state)
    query = query_state["query"]

    # Plan-only statements; EXPLAIN ANALYZE would actually run the query.
    explain_statements = {
        'mysql': "EXPLAIN {query}",
        'mariadb': "EXPLAIN {query}",
        'sqlite': "EXPLAIN QUERY PLAN {query}",
        'oracle': "EXPLAIN PLAN FOR\n{query}\n\nSELECT * FROM TABLE(DBMS_XPLAN.DISPLAY);",
        'postgresql': "EXPLAIN {query}",
        'postgres': "EXPLAIN {query}",
        'presto': "EXPLAIN {query}",
        # NOTE(review): this form also executes the query; consider a plan-only option.
        'sqlserver': "SET STATISTICS PROFILE ON; {query}; SET STATISTICS PROFILE OFF;"
    }

    if dialect.lower() not in explain_statements:
        query_plan = " "
    else:
        try:
            plan_target = query.strip().rstrip(";")
            explain_query = explain_statements[dialect.lower()].format(query=plan_target)
            with Session() as session:
                result = session.execute(text(explain_query))
                query_plan = "\n".join([str(row) for row in result])
        except Exception as e:
            query_state["status"] = "error"
            query_state["error"]["code"] = "E01"
            query_state["error"]["message"] = f"An error occurred while executing the EXPLAIN query: {str(e)}"
            query_state["error"]["failed_step"] = "validation"
            query_state["query"] = query
            return query_state

    sys_prompt_template = "You are a database expert who reviews existing {dialect} SQL queries in response to user questions and optimizes them when necessary. Your task is to examine the query's coherence and potential for optimization based on the given SQL query and additional information, and provide a final query based on this analysis."
    usr_prompt_template = "Please add aliases to the query to match the user's question. It is not allowed to add tables or columns that were not used in the original SQL query. Skip the introduction and provide only the generated SQL query statement. \n\n #Question: {question}\n\n #Existing query:\n {query}\n\n #Query plan:\n {query_plan}"
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, dialect=dialect, query=query, query_plan=query_plan)
    validated_query = converse_with_bedrock(sys_prompt, usr_prompt)
    query_state["query"] = validated_query

    return query_state

def execute_query_dev(query_state):
    """Execute the query held in ``query_state`` against ``db``.

    On success the raw return value of ``db.run`` is stored under ``"result"``.
    On failure the state is marked ``"error"`` with code ``E02`` and
    ``failed_step`` ``"execution"``, and the exception text goes into the
    error message.

    Note: the passed-in dict is updated in place; a deep copy of it is
    returned either way.
    """
    sql_text = query_state["query"]
    try:
        rows = db.run(sql_text)
    except Exception as e:
        query_state["status"] = "error"
        failure = query_state["error"]
        failure["code"] = "E02"
        failure["message"] = f"An error occurred while executing the validated query: {str(e)}"
        failure["failed_step"] = "execution"
    else:
        query_state["result"] = rows
    return copy.deepcopy(query_state)
    
def handle_failure_dev(query_state):
    """Ask the LLM to classify a failed query and suggest how to fix it.

    Args:
        query_state: Query state dict. Reads ``query`` and ``error["message"]``
            and writes the model's ``hint`` at the top level (in place).

    Returns:
        ``(failure_type, state_copy)``. ``failure_type`` is the model's
        ``failure_type`` value, or ``"Invalid Response"`` when the reply lacks
        that key or cannot be parsed as a JSON object. ``state_copy`` is a deep
        copy of the updated state.
    """
    query = query_state['query']
    message = query_state['error']['message']
    sys_prompt_template = "You are a skilled database engineer who handles SQL query failures. Your task is to identify the cause of failure for the given SQL query and determine the next steps for problem resolution."
    usr_prompt_template = "Based on the failure message of the given SQL query, provide one of the following causes (`failure_type`) along with a clue for resolution (`hint`).\nHere are examples of failure_type choices:\nInaccurate query syntax: `syntax_check`\nSchema mismatch: `schema_check`\nExternal DB factors (permissions, connection issues, etc.): `stop`\nTemporary DB malfunction (query re-execution needed): `retry`\n\n#Failed query: {query}\n\n#Failure message: {message}\n\n#Format: {json_response_format} Skip the preamble and only provide the valid JSON document."
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, query=query, message=message, json_response_format=json_response_format)
    result = converse_with_bedrock(sys_prompt, usr_prompt)

    # Models sometimes wrap the JSON in a Markdown fence (```json ... ```)
    # despite the instruction, so strip it before parsing.
    cleaned = result.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`").strip()
        if cleaned.lower().startswith("json"):
            cleaned = cleaned[len("json"):]
    try:
        json_result = json.loads(cleaned)
    except json.JSONDecodeError:
        # Fall through to the same defaults used for missing keys below.
        json_result = {}
    if not isinstance(json_result, dict):
        json_result = {}

    query_state["hint"] = json_result.get("hint", "None")
    failure_type = json_result.get("failure_type", "Invalid Response")
    return failure_type, copy.deepcopy(query_state)

def get_relevant_columns_dev(question, query_state):
    """Gather extra schema context for a failed query and append it to the hint.

    The LLM proposes comma-separated search keywords from the question, the
    failed query and its error message. Each non-empty keyword is looked up
    with ``search_by_keywords``, and the results are appended to the state's
    ``hint`` under an ``#Additional Schema:`` header.

    Args:
        question: The user's natural-language question.
        query_state: Query state dict with ``query`` and ``error["message"]``;
            a top-level ``hint`` is used when present.

    Returns:
        A new (deep-copied) query state with the extended ``hint``; the input
        dict is not modified.
    """
    query_state = copy.deepcopy(query_state)
    query = query_state["query"]
    message = query_state['error']['message']
    sys_prompt_template = "You are an expert SQL query troubleshooter. Your task is to analyze failed queries and suggest relevant keywords for schema exploration to resolve the issue."
    usr_prompt_template = """Given a user question, a failed SQL query, and an error message, provide 3-5 relevant keywords or phrases for database schema exploration. These should help in finding the correct table and column names to fix the query.\n\n#User question: {question}\n\n#Failed query:\n{query}\n\n#Error message:\n{message}\n\nRespond only with a comma-separated list of keywords or short phrases, without any additional text or explanation.\n\n#Format: {csv_list_response_format}"""
    sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, query=query, message=message, csv_list_response_format=csv_list_response_format)
    response = converse_with_bedrock(sys_prompt, usr_prompt)

    # Strip surrounding code-fence backticks, then drop blanks left by
    # trailing or doubled commas so we never search for an empty keyword.
    raw_keywords = response.strip().strip('`').split(',')
    keyword_list = [keyword.strip() for keyword in raw_keywords if keyword.strip()]
    print("search keyword:", keyword_list)

    schema_context = "".join(search_by_keywords(keyword) + "\n" for keyword in keyword_list)
    # The hint is normally set by handle_failure; fall back to the "None"
    # placeholder used elsewhere if this node is reached without one.
    query_state["hint"] = query_state.get("hint", "None") + "\n\n#Additional Schema:\n" + schema_context

    return query_state
    

### SubGraph2 - Code modules (Test)

In [10]:
#===================================================================================================
# Manual checks of the SubGraph2 dev helpers, driven by hand-built inputs
# rather than by the outputs of the upstream nodes.

#7 - generate_query
# Inputs for generate_query_dev: target dialect, user question, one example
# question/query pair, and descriptions of the Customer and Invoice tables
# (column descriptions, CREATE TABLE DDL, and three sample rows each).
dialect = "sqlite"
question = "What are the top 10 countries by sales in 2022?"
sample_queries =  [{'input': 'What are the top 5 customers by total purchase amount?',
  'query': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5'}]
table_details = [{'table': 'Customer',
  'cols': {'CustomerId': 'Primary key, unique customer identifier.',
   'FirstName': 'First name of the customer.',
   'LastName': 'Last name of the customer.',
   'Company': 'Company of the customer.',
   'Address': 'Address of the customer.',
   'City': 'City of the customer.',
   'State': 'State of the customer.',
   'Country': 'Country of the customer.',
   'PostalCode': 'Postal code of the customer.',
   'Phone': 'Phone number of the customer.',
   'Fax': 'Fax number of the customer.',
   'Email': 'Email address of the customer.',
   'SupportRepId': 'Foreign key that references the employee who supports this customer.'},
  'create_table_sql': 'CREATE TABLE "Customer" (\n\t"CustomerId" INTEGER NOT NULL, \n\t"FirstName" NVARCHAR(40) NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"Company" NVARCHAR(80), \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60) NOT NULL, \n\t"SupportRepId" INTEGER, \n\tPRIMARY KEY ("CustomerId"), \n\tFOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")\n)',
  'sample_data': '3 rows from Customer table:\nCustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\tluisg@embraer.com.br\t3\n2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\tleonekohler@surfeu.de\t5\n3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\tftremblay@gmail.com\t3'},
 {'table': 'Invoice',
  'cols': {'InvoiceId': 'Primary key, unique identifier for the invoice.',
   'CustomerId': 'Foreign key that references the customer associated with this invoice.',
   'InvoiceDate': 'Date when the invoice was issued.',
   'BillingAddress': 'Billing address on the invoice.',
   'BillingCity': 'Billing city on the invoice.',
   'BillingState': 'Billing state on the invoice.',
   'BillingCountry': 'Billing country on the invoice.',
   'BillingPostalCode': 'Billing postal code on the invoice.',
   'Total': 'Total amount of the invoice.'},
  'create_table_sql': 'CREATE TABLE "Invoice" (\n\t"InvoiceId" INTEGER NOT NULL, \n\t"CustomerId" INTEGER NOT NULL, \n\t"InvoiceDate" DATETIME NOT NULL, \n\t"BillingAddress" NVARCHAR(70), \n\t"BillingCity" NVARCHAR(40), \n\t"BillingState" NVARCHAR(40), \n\t"BillingCountry" NVARCHAR(40), \n\t"BillingPostalCode" NVARCHAR(10), \n\t"Total" NUMERIC(10, 2) NOT NULL, \n\tPRIMARY KEY ("InvoiceId"), \n\tFOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")\n)',
  'sample_data': '3 rows from Invoice table:\nInvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n1\t2\t2021-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n2\t4\t2021-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n3\t8\t2021-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94'}]

# hint is the string "None" (not the None object), the same placeholder the
# query-state fixtures below use when no hint is available yet.
print(generate_query_dev(question=question, dialect=dialect, sample_queries=sample_queries, table_details=table_details, hint="None"))

#8 - validate_query
# A freshly generated (un-aliased) query, ready to be reviewed by the validator.
question = "What are the top 10 countries by sales in 2022?"
query_state = dict(
    status='success',
    query="SELECT BillingCountry, SUM(Total) AS TotalSales\nFROM Invoice\nWHERE InvoiceDate BETWEEN '2022-01-01' AND '2022-12-31'\nGROUP BY BillingCountry\nORDER BY TotalSales DESC\nLIMIT 10;",
    result='',
    error=dict(code='', message='', failed_step='', hint='None'),
)
print(validate_query_dev(question, dialect, query_state))

#9 - execute_query
# The validated, aliased form of the same query, ready to run against the database.
query_state = dict(
    status='success',
    query="SELECT BillingCountry AS Country, SUM(Total) AS TotalSales\nFROM Invoice i\nWHERE i.InvoiceDate BETWEEN '2022-01-01' AND '2022-12-31'\nGROUP BY Country\nORDER BY TotalSales DESC\nLIMIT 10;",
    result='',
    error=dict(code='', message='', failed_step='', hint='None'),
)
print(execute_query_dev(query_state))

#10 - handle_failure
# Case 1: schema mismatch (the query references a non-existent column "Billing").
query_state_type1 = {'status': 'error',
 'query': "SELECT Billing AS Country, SUM(Total) AS TotalSales\nFROM Invoice i\nWHERE i.InvoiceDate BETWEEN '2022-01-01' AND '2022-12-31'\nGROUP BY Country\nORDER BY TotalSales DESC\nLIMIT 10;",
 'result': '',
 'error': {'code': 'E02',
  'message': "An error occurred while executing the validated query: (sqlite3.OperationalError) no such column: Billing\n[SQL: SELECT Billing AS Country, SUM(Total) AS TotalSales\nFROM Invoice i\nWHERE i.InvoiceDate BETWEEN '2022-01-01' AND '2022-12-31'\nGROUP BY Country\nORDER BY TotalSales DESC\nLIMIT 10;]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)",
  'failed_step': 'execution',
  'hint': "None"}}
# handle_failure_dev returns (failure_type, query_state). Print it explicitly:
# only the last expression of a notebook cell is displayed automatically.
print(handle_failure_dev(query_state_type1))

# Case 2: syntax error (misspelled keyword "SELECTS").
query_state_type2 = {'status': 'error',
 'query': "SELECTS BillingCountry AS Country, SUM(Total) AS TotalSales\nFROM Invoice i\nWHERE i.InvoiceDate BETWEEN '2022-01-01' AND '2022-12-31'\nGROUP BY Country\nORDER BY TotalSales DESC\nLIMIT 10;",
 'result': '',
 'error': {'code': 'E02',
  'message': 'An error occurred while executing the validated query: (sqlite3.OperationalError) near "SELECTS": syntax error\n[SQL: SELECTS BillingCountry AS Country, SUM(Total) AS TotalSales\nFROM Invoice i\nWHERE i.InvoiceDate BETWEEN \'2022-01-01\' AND \'2022-12-31\'\nGROUP BY Country\nORDER BY TotalSales DESC\nLIMIT 10;]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)',
  'failed_step': 'execution',
  'hint': 'None'}}
print(handle_failure_dev(query_state_type2))

#11 - get_relevant_column
# A schema-mismatch failure that handle_failure has already annotated with a
# top-level hint. The error message quotes the same SQL as 'query', as it would
# for a real execution failure.
question = "What are the top 10 countries by sales in 2022?"
query_state = {'status': 'error',
  'query': "SELECT Country AS Country, SUM(Total) AS TotalSales\nFROM Invoice i\nWHERE i.InvoiceDate BETWEEN '2022-01-01' AND '2022-12-31'\nGROUP BY Country\nORDER BY TotalSales DESC\nLIMIT 10;",
  'result': '',
  'error': {'code': 'E02',
   'message': "An error occurred while executing the validated query: (sqlite3.OperationalError) no such column: Country\n[SQL: SELECT Country AS Country, SUM(Total) AS TotalSales\nFROM Invoice i\nWHERE i.InvoiceDate BETWEEN '2022-01-01' AND '2022-12-31'\nGROUP BY Country\nORDER BY TotalSales DESC\nLIMIT 10;]\n(Background on this error at: https://sqlalche.me/e/20/e3q8)",
   'failed_step': 'execution',
   'hint': 'None'},
  'hint': "The column name 'Country' does not exist in the Invoice table."}

print(get_relevant_columns_dev(question, query_state))

### Answer generation nodes

In [11]:
def get_general_answer(state: GraphState) -> GraphState:
    """Answer a question that does not need the database.

    The LLM answers ``state["question"]`` directly. Returns a new
    ``GraphState`` holding only ``answer``.
    """
    user_question = state["question"]
    system_text = "You are a capable assistant who answers general questions from users. If you don't know the answer to a question, admit that you don't know."
    user_text = "#Question: {question}"
    system_msgs, user_msgs = create_prompt(system_text, user_text, question=user_question)
    return GraphState(answer=converse_with_bedrock(system_msgs, user_msgs))

def get_database_answer(state: GraphState) -> GraphState:
    """Produce the final answer for a database question.

    If ``state["query_state"]["status"]`` is ``"success"``, the LLM answers
    from the query and its result data. Otherwise it explains the failure
    using the failed step and error message. Returns a new ``GraphState``
    holding only ``answer``.
    """
    user_question = state["question"]
    qs = state["query_state"]
    used_query, rows = qs["query"], qs["result"]
    err = qs["error"]
    step, err_msg = err["failed_step"], err["message"]
    system_text = "You are a competent assistant who answers user questions based on database information. Your task is to provide thorough answers to user questions, referencing the given information."

    if qs["status"] != "success":
        failure_text = "The following is a record of a failed query execution for a user question. Based on this, explain why the request processing failed.\n\n#Question: {question}\n\n#Used query: {query}\n\n#Failed step: {failed_step}\n\n#Error message: {message}\n\n"
        system_msgs, user_msgs = create_prompt(system_text, failure_text, question=user_question, query=used_query, failed_step=step, message=err_msg)
    else:
        success_text = "The answer should include the used query, dataframe (as a Markdown Table), and a brief response to the question. \n\n#Question: {question}\n\n#Used query: {query}\n\n#Data: {data}\n\n"
        system_msgs, user_msgs = create_prompt(system_text, success_text, question=user_question, query=used_query, data=rows)

    return GraphState(answer=converse_with_bedrock(system_msgs, user_msgs))


### Answer generation - Code modules (Dev)

In [12]:
def get_database_answer_dev(question, query_state) -> GraphState:
    """Dev variant of ``get_database_answer`` that takes its inputs directly.

    Args:
        question: The user's natural-language question.
        query_state: Query state dict with ``status``, ``query``, ``result``
            and ``error`` (``failed_step``, ``message``) entries.

    Returns:
        A ``GraphState`` holding only ``answer``. On success the answer is
        built from the query and its result data; otherwise it is the LLM's
        explanation of why processing failed.
    """
    query = query_state["query"]
    data = query_state["result"]
    failed_step = query_state["error"]["failed_step"]
    message = query_state["error"]["message"]
    sys_prompt_template = "You are a competent assistant who answers user questions based on database information. Your task is to provide thorough answers to user questions, referencing the given information."

    if query_state["status"] == "success":
        usr_prompt_template = "The answer should include the used query, dataframe (as a Markdown Table), and a brief response to the question. \n\n#Question: {question}\n\n#Used query: {query}\n\n#Data: {data}\n\n"
        sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, query=query, data=data)
    else:
        usr_prompt_template = "The following is a record of a failed query execution for a user question. Based on this, explain why the request processing failed.\n\n#Question: {question}\n\n#Used query: {query}\n\n#Failed step: {failed_step}\n\n#Error message: {message}\n\n"
        sys_prompt, usr_prompt = create_prompt(sys_prompt_template, usr_prompt_template, question=question, query=query, failed_step=failed_step, message=message)

    answer = converse_with_bedrock(sys_prompt, usr_prompt)
    return GraphState(answer=answer)

### Answer generation - Verification

In [13]:
question = "What are the top 10 countries by sales in 2022?"
# A successful execution whose result rows are stored as a string, which is
# the form the answer prompt embeds under "#Data".
query_state = dict(
    status='success',
    query="SELECT BillingCountry AS Country, SUM(Total) AS TotalSales\nFROM Invoice i\nWHERE i.InvoiceDate BETWEEN '2022-01-01' AND '2022-12-31'\nGROUP BY Country\nORDER BY TotalSales DESC\nLIMIT 10;",
    result="[('USA', 102.98), ('Canada', 76.26), ('Brazil', 41.6), ('France', 39.6), ('Hungary', 32.75), ('United Kingdom', 30.689999999999998), ('Austria', 27.77), ('Germany', 25.74), ('Chile', 17.91), ('India', 17.83)]",
    error=dict(code='', message='', failed_step='', hint='None'),
)

answer = get_database_answer_dev(question, query_state)
print(answer)

### LangGraph Workflow - Graph Construction

In [17]:
from langgraph.graph import END, StateGraph
from langgraph.checkpoint.memory import MemorySaver

# Build the Text2SQL workflow over the shared GraphState schema.
workflow = StateGraph(GraphState)

# Global Nodes
workflow.add_node("analyze_intent", analyze_intent)
workflow.add_node("get_general_answer", get_general_answer)
workflow.add_node("get_database_answer", get_database_answer)
# Every run starts by classifying the question's intent.
workflow.set_entry_point("analyze_intent")

# SubGraph1 Nodes - Schema Linking
workflow.add_node("get_sample_queries", get_sample_queries)
workflow.add_node("check_readiness", check_readiness)
workflow.add_node("get_relevant_tables", get_relevant_tables)
workflow.add_node("describe_schema", describe_schema)

# SubGraph2 Nodes - Query Generation & Execution
workflow.add_node("generate_query", generate_query)
workflow.add_node("validate_query", validate_query)
workflow.add_node("execute_query", execute_query)
workflow.add_node("handle_failure", handle_failure)
workflow.add_node("get_relevant_columns", get_relevant_columns)

# Edge from Entry to SubGraph1
# next_step_by_intent's return value selects the branch: database questions
# go to schema linking, everything else is answered directly.
workflow.add_conditional_edges(
    "analyze_intent",
    next_step_by_intent,
    {
        "database": "get_sample_queries",
        "general": "get_general_answer",
    }
)

# Edges in SubGraph1
workflow.add_edge("get_sample_queries", "check_readiness")
# Loop: while the collected context is "Not Ready", fetch and describe more
# tables, then re-check readiness (describe_schema -> check_readiness below).
workflow.add_conditional_edges(
    "check_readiness"    ,
    next_step_by_readiness,
    {
        "Ready": "generate_query",
        "Not Ready": "get_relevant_tables"
    }
)
workflow.add_edge("get_relevant_tables", "describe_schema")
workflow.add_edge("describe_schema", "check_readiness")

# Edges in SubGraph2
workflow.add_edge("generate_query", "validate_query")
# Validation and execution share one router: "success" moves forward,
# "error" goes to handle_failure.
workflow.add_conditional_edges(
    "validate_query"    ,
    next_step_by_query_state,
    {
        "success": "execute_query",
        "error": "handle_failure"
    }
)
workflow.add_conditional_edges(
    "execute_query"    ,
    next_step_by_query_state,
    {
        "success": "get_database_answer",
        "error": "handle_failure"
    }
)
# Recovery routing by the failure type handle_failure chose:
#   schema_check -> look up more columns, then regenerate the query
#   syntax_check -> regenerate the query
#   retry        -> re-enter at validation (which precedes execution)
#   stop         -> give up and explain the failure to the user
# NOTE(review): any other value returned by next_step_by_next_action has no
# mapping here; confirm that router only returns these four keys.
workflow.add_conditional_edges(
    "handle_failure"    ,
    next_step_by_next_action,
    {
        "schema_check": "get_relevant_columns",
        "syntax_check": "generate_query",
        "retry": "validate_query",
        "stop": "get_database_answer"
    }
)
workflow.add_edge("get_relevant_columns", "generate_query")

# Edges to END
workflow.add_edge("get_general_answer", END)
workflow.add_edge("get_database_answer", END)

# In-memory checkpointer: run state is saved per thread_id supplied in the
# run config, so reusing a thread_id continues that thread's saved state.
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)

### LangGraph Workflow - Graph Visualization

In [None]:
from IPython.display import Image, display

# Rendering the Mermaid PNG can fail at runtime (e.g. the rendering backend is
# unavailable). The picture is optional, so report why and carry on instead of
# silently swallowing everything (including KeyboardInterrupt) with a bare except.
try:
    display(
        Image(app.get_graph(xray=True).draw_mermaid_png())
    )
except Exception as e:
    print(f"Graph visualization skipped: {e}")

### LangGraph Workflow End-to-End Test

In [None]:
import pprint
import uuid
from langgraph.errors import GraphRecursionError
from langchain_core.runnables import RunnableConfig

# Give each run its own thread_id. The app is compiled with a MemorySaver
# checkpointer, so a shared, hard-coded id would make later runs resume this
# run's saved state instead of starting from a clean state.
config = RunnableConfig(recursion_limit=100, configurable={"thread_id": str(uuid.uuid4())})
inputs = GraphState(question="What are the top 10 countries by sales in 2022?")

pp = pprint.PrettyPrinter(width=200, compact=True)

# Stream node-by-node updates and pretty-print every key each node returned.
try:
    for output in app.stream(inputs, config=config):
        for key, value in output.items():
            print(f"\n🔹 [NODE] {key}")
            print("=" * 80)
            for k, v in value.items():
                print(f"📌 {k}:")
                pp.pprint(v)
            print("=" * 80)
except GraphRecursionError as e:
    # Raised when the retry/readiness loops exceed recursion_limit steps.
    print(f"⚠️ Recursion limit reached: {e}")

In [None]:
import pprint
import uuid
from langgraph.errors import GraphRecursionError
from langchain_core.runnables import RunnableConfig

# Fresh thread_id per run. Reusing the previous cell's id ("TODO") with the
# MemorySaver checkpointer would resume that thread's saved state (earlier
# query_state, answer, ...) instead of answering this question from scratch.
config = RunnableConfig(recursion_limit=100, configurable={"thread_id": str(uuid.uuid4())})
inputs = GraphState(question="What is AWS?")

pp = pprint.PrettyPrinter(width=200, compact=True)

# Stream node-by-node updates and pretty-print every key each node returned.
try:
    for output in app.stream(inputs, config=config):
        for key, value in output.items():
            print(f"\n🔹 [NODE] {key}")
            print("=" * 80)
            for k, v in value.items():
                print(f"📌 {k}:")
                pp.pprint(v)
            print("=" * 80)
except GraphRecursionError as e:
    print(f"⚠️ Recursion limit reached: {e}")