## **Text2sql in BigQuery**
#### based on Google's Open Data QnA architecture

> This notebook is a LangGraph implementation that can be adapted for production use.

![architecture](../OpenDataQnA_architecture.png)

-------------

In [248]:
import sys, os
sys.path.append(os.path.abspath(".."))
import time
from dotenv import load_dotenv
load_dotenv()
from config.gcp import initialize_gcs_client
initialize_gcs_client()
from utils.const import MODEL_CONFIG, MAXIMUM_DEBUG

### Initialize Graph State

In [249]:
from typing import TypedDict
from pandas import DataFrame

In [250]:
class AgentState(TypedDict):
    original_question: str
    rewrite_question: str
    error_message: str | None
    debugging_count: int
    query: str
    related_columns: list
    related_query: list
    result_response: str
    result_df: DataFrame | str
    formatted_history: list
    chart_suggestions: list[str]
    chart_js_codes: list
    chart_couter: int
    text2sql_time: float
    visualization_time: float

--------

### **Nodes**

In [251]:
def setup(state: AgentState) -> AgentState:
    """A node for setting up the variables"""
    state["chart_js_codes"] = []
    state["chart_couter"] = 0
    
    return state
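
The nodes in this notebook use two styles: `setup` mutates `state` and returns it whole, while most later nodes return only the keys they changed. LangGraph merges a returned dict into the graph state key by key (overwriting by default), so both styles work. The sketch below imitates that merge in plain Python to show the effect; `apply_update` and `fake_setup` are hypothetical names used only for illustration and are not LangGraph APIs.

```python
from typing import Any, Callable

State = dict[str, Any]


def apply_update(state: State, node: Callable[[State], State]) -> State:
    """Imitate the default merge: keys returned by the node overwrite the state."""
    update = node(dict(state))  # hand the node a copy (illustration only)
    merged = dict(state)
    merged.update(update)
    return merged


def fake_setup(state: State) -> State:
    # Partial update: only the keys this node is responsible for.
    return {"chart_js_codes": [], "chart_couter": 0}


state = apply_update({"original_question": "Give me last 3 reports"}, fake_setup)
print(state)
```

Returning only the changed keys keeps each node's responsibility visible, which is why most nodes below use that style.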

#### <mark>**1. Rewrite question Agent**</mark>

This node is an agent that rewrites and refines the question using the previously asked questions, producing a standalone question that carries the context of the conversation history.

In [252]:
from src.llm import init_agent_model
from pydantic import BaseModel, Field

In [253]:
# Output Parser Schema
class RewriteQuestionParser(BaseModel):
    rewrite_question: str = Field(description="The rewrite question")

In [254]:
SYSTEM_PROMPT1: str = """
Your main objective is to rewrite and refine the question based on the previous questions that have been asked.

Refine the given question using the provided questions history to produce a standalone question with full context. The refined question should be self-contained, requiring no additional context for answering it.

Make sure all the information is included in the re-written question. You just need to respond with the re-written question.

Below is the previous questions history:

{formatted_history}

Your output should be in this format. 
{format_instructions}
"""

HUMAN_PROMPT1: str = """
Question to rewrite:
{question}
"""

PROMPT_TEMPLATE1: list[tuple] = [
    ('system', SYSTEM_PROMPT1),
    ('human', HUMAN_PROMPT1)
]

In [None]:
def rewrite_question_agent_node(state: AgentState) -> AgentState:
    """Rewrite question by an agent"""
    state["text2sql_time"] = time.time()
    print("\033[92m--- Rewrite question ---\033[00m")
    
    chain = init_agent_model(prompt=PROMPT_TEMPLATE1, model_config=MODEL_CONFIG, parser=RewriteQuestionParser)
    
    response = chain.invoke({
        "formatted_history": state.get("formatted_history", []), # If there is no "formatted_history" then use an empty list
        "question": state["original_question"]
    })
    
    # Set rewrite_question state to be the response from an agent
    state["rewrite_question"] = response["rewrite_question"]
    return state
    
# if __name__ == '__main__' :
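
The rewrite node passes `formatted_history` to the prompt as a raw Python list, and the response node later appends one list of `(key, value)` tuples per turn. A minimal sketch of how that structure could be rendered as compact text and capped to the most recent turns is shown below. `render_history` and `max_turns` are hypothetical and not part of the notebook; the field names are the ones the response node stores.

```python
def render_history(history: list[list[tuple[str, str]]], max_turns: int = 3) -> str:
    """Render the most recent turns as plain text for the prompt."""
    lines: list[str] = []
    # Keep only the last `max_turns` turns so the prompt does not grow without bound.
    for turn in history[-max_turns:]:
        fields = dict(turn)
        lines.append(f"Q: {fields.get('rewrite_question', '')}")
        lines.append(f"A: {fields.get('result_response', '')}")
    return "\n".join(lines)


history = [
    [("rewrite_question", "Give me the last 3 reports"), ("result_response", "Here are 3 files.")],
    [("rewrite_question", "Which machine has the most reports?"), ("result_response", "W2B11")],
]
print(render_history(history, max_turns=1))
```

Dropping the column details and result tables from older turns is a design choice in this sketch: the rewrite step mainly needs the earlier questions and answers.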

#### <mark>**2. Search Database Column**</mark>

This node searches the database column descriptions to give the LLM context for producing SQL.

In [256]:
from src.vector import similar_search

In [None]:
def search_column_description_node(state: AgentState) -> AgentState:
    """Search column description"""
    print("\033[92m--- Search column description ---\033[00m")
    similar_search_result = similar_search(vectore_db="mro_db_column_description", search_input=state["rewrite_question"])
    
    # Set related_columns state to be the result from the column description vector result
    return {
        "related_columns": similar_search_result
    }
    
# if __name__ == '__main__' :
#     print(search_column_description_node({'rewrite_question' : 'Who are you'}))
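
The implementation of `similar_search` is not shown in this notebook. As a rough illustration of what a similarity search over column descriptions typically does, the sketch below ranks toy vectors by cosine similarity. The vectors, `cosine`, and `top_k` are all assumptions made for the example; a real setup would use embeddings from a model and a vector store.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec: list[float], docs: dict[str, list[float]], k: int = 2) -> list[str]:
    """Return the names of the k entries most similar to the query vector."""
    ranked = sorted(docs, key=lambda name: cosine(query_vec, docs[name]), reverse=True)
    return ranked[:k]


# Toy 3-dimensional "embeddings" (made up for illustration).
columns = {
    "Files": [0.9, 0.1, 0.0],
    "Datess": [0.1, 0.9, 0.0],
    "MachineCode": [0.0, 0.1, 0.9],
}
print(top_k([0.8, 0.3, 0.0], columns, k=2))
```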

#### <mark>**3. Related Query**</mark>

This node provides example SQL queries that are used as context when writing and fixing SQL.

In [None]:

def search_query_node(state: AgentState) -> AgentState:
    print("\033[92m--- Search query ---\033[00m")
    return {
        "related_query" : [
        """
        question : Can I get the latest report files in July? 
        query : SELECT Files FROM `cbm-cgs-uiim-prd.mro_demo.mrodata` WHERE PARSE_DATE('%d-%b-%y', Datess) BETWEEN '2025-07-01' AND '2025-07-31' ORDER BY PARSE_DATE('%d-%b-%y', Datess) DESC LIMIT 3"""
        ]}
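
The node above always returns the same hard-coded example. If more examples were collected, one simple way to choose among them would be lexical overlap with the rewritten question. The sketch below ranks stored question/query pairs by Jaccard similarity; `pick_examples`, `tokens`, and the placeholder queries are hypothetical and only illustrate the idea.

```python
import re


def tokens(text: str) -> set[str]:
    """Lowercased alphanumeric tokens of a text."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def pick_examples(question: str, examples: list[dict[str, str]], k: int = 1) -> list[dict[str, str]]:
    """Rank stored question/query pairs by Jaccard overlap with the new question."""
    q = tokens(question)

    def score(example: dict[str, str]) -> float:
        e = tokens(example["question"])
        union = q | e
        return len(q & e) / len(union) if union else 0.0

    return sorted(examples, key=score, reverse=True)[:k]


# Placeholder examples; the query strings stand in for real SQL.
examples = [
    {"question": "Can I get the latest report files in July?", "query": "<query 1>"},
    {"question": "How many machines are there?", "query": "<query 2>"},
]
best = pick_examples("Show me the latest report files", examples)
print(best[0]["query"])
```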

#### <mark>**4. Write SQL agent**</mark>

This node is an agent that writes SQL based on user question.

In [259]:
class WriteSQLParser(BaseModel):
    query: str = Field(description="Generated query (The answer of the SQL query)")

In [260]:
SYSTEM_PROMPT2: str = """
You are a BigQuery SQL guru. Your task is to write a BigQuery SQL query that answers the following question while using the provided context.

<Guidelines>
- Join as few tables as possible.
- When joining tables, ensure all join columns have the same data type.
- Analyze the database and the table schema provided as parameters and understand the relations (column and table relations).
- Always use SAFE_CAST. If performing a SAFE_CAST, use only BigQuery-supported data types (i.e. specific_data_types).
- Always SAFE_CAST before applying aggregate functions.
- Don't include any comments in code.
- Remove ```sql and ``` from the output and generate the SQL in single line.
- Tables should be referred to using a fully qualified name enclosed in backticks (`), e.g. `project_id.owner.table_name`.
- Use all the non-aggregated columns from the "SELECT" statement when framing the "GROUP BY" block.
- Return syntactically and semantically correct SQL for BigQuery with proper relation mapping, i.e. project_id, owner, table and column relation.
- Use ONLY the column names (column_name) mentioned in Table Schema. DO NOT USE any other column names outside of this.
- Associate column_name mentioned in Table Schema only to the table_name specified under Table Schema.
- Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed.
- Table names are case sensitive. DO NOT uppercase or lowercase the table names.
- Always enclose subqueries and union queries in brackets.
- Refer to the examples provided below, if given. 
- When the given question is out of context for this session, always respond with the dummy SQL statement - not_related_msg
- Always generate SELECT queries ONLY. If asked for other statements such as DELETE or MERGE, respond with the dummy SQL statement - not_related_msg
</Guidelines>

Your output should be in this format. 
{format_instructions}
"""

HUMAN_PROMPT2: str = """
<User's question>
{rewrite_question}
</User's question>

<Examples>
{related_query}
</Examples>

<Table Schema>
table : `cbm-cgs-uiim-prd.mro_demo.mrodata`
</Table Schema>

<Columns Schema>
{related_columns}
</Columns Schema>
"""

PROMPT_TEMPLATE2: list[tuple]= [
    ('system', SYSTEM_PROMPT2),
    ('human', HUMAN_PROMPT2)
]

In [None]:
def write_sql_agent_node(state: AgentState) -> AgentState:
    """Write SQL by an agent"""
    print("\033[92m--- Write SQL ---\033[00m")
    
    chain = init_agent_model(prompt=PROMPT_TEMPLATE2, model_config=MODEL_CONFIG, parser=WriteSQLParser)
    
    response = chain.invoke({
        "rewrite_question": state["rewrite_question"],
        "related_query": state["related_query"],
        "related_columns": state["related_columns"]
    })
    
    print(response["query"])
    
    # Set query (sql query) state to be the response from an agent
    return {
        "query": response["query"]
    }

#### <mark>**5. Check SQL syntax node**</mark>

This node checks the SQL produced by the "write sql agent node" by executing it, and sets the error message state accordingly.

In [262]:
from src.bq import GcpBigQuery

In [263]:
# Big Query Environments -------
project_id: str = os.getenv("PROJECT_ID")
location: str = os.getenv("LOCATION")
dataset_id: str = os.getenv("DATASET_ID")
table_name: str = "mrodata"
# ------------------------------

# Initiate database connection -
big_query = GcpBigQuery(
    project_id=project_id,
    location=location,
    dataset_id=dataset_id
)

In [None]:

def check_sql_node(state: AgentState) -> AgentState:
    """Check SQL by executing it"""
    print("\033[92m--- Check SQL ---\033[00m")
    error_message: str | None
    debugging_count: int
    
    try:
        _ = big_query.execute_query(state["query"])
        error_message = None
        debugging_count = state.get('debugging_count', 0)
    except Exception as e:
        error_message = str(e)
        debugging_count = state.get('debugging_count', 0) + 1
        
    # Set error_message and debugging_count states for checking edge
    return {
        "error_message" : error_message,
        "debugging_count" : debugging_count
    }
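
`check_sql_node` sends every generated string to BigQuery, including the `not_related_msg` sentinel and any non-SELECT statement the model might emit despite the guidelines. A coarse local check before execution could catch those cases without a round trip. The sketch below is one possible guard under that assumption; `precheck_query` and its keyword list are hypothetical, and the keyword filter is deliberately blunt (it would also reject a SELECT that merely mentions one of these words in a string literal).

```python
import re
from typing import Optional

NOT_RELATED = "not_related_msg"
_WRITE_KEYWORDS = re.compile(
    r"\b(INSERT|UPDATE|DELETE|MERGE|DROP|CREATE|ALTER|TRUNCATE)\b", re.IGNORECASE
)


def precheck_query(query: str) -> Optional[str]:
    """Return an error message for queries that should never be executed, else None."""
    stripped = query.strip()
    if stripped == NOT_RELATED:
        return NOT_RELATED
    # Accept statements that start with SELECT or WITH, optionally parenthesized.
    if not re.match(r"^\(*\s*(SELECT|WITH)\b", stripped, re.IGNORECASE):
        return f"{NOT_RELATED}: only SELECT statements are allowed"
    if _WRITE_KEYWORDS.search(stripped):
        return f"{NOT_RELATED}: statement contains a write keyword"
    return None


print(precheck_query("SELECT Files FROM `p.d.t` LIMIT 3"))
print(precheck_query("DELETE FROM `p.d.t` WHERE TRUE"))
```

Because every rejection message contains `not_related_msg`, the existing routing in `check_maximum_debug` would send these cases to the error response without further changes.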

#### <mark>**6. Run SQL syntax node**</mark>

This node runs the SQL query once it has passed the check without errors.

In [None]:
def run_sql_node(state: AgentState) -> AgentState:
    """Run SQL"""
    print("\033[92m--- Run SQL ---\033[00m")
    result = big_query.execute_query(state["query"])
    
    # Set result_df state
    return {
        "result_df": result
    }

#### <mark>**7. Fix SQL syntax agent node**</mark>

This node is an agent that fixes invalid SQL produced by write_sql_agent_node.

In [266]:
class FixSQLParser(BaseModel):
    query: str = Field(description="Fixed query")

In [267]:
SYSTEM_PROMPT3: str = """
You are a BigQuery SQL guru. Your task is to troubleshoot a BigQuery SQL query. As the user provides versions of the query and the errors returned by BigQuery,
return a new alternative SQL query that fixes the errors. It is important that the query still answers the original question.

<Guidelines>
- Join as few tables as possible.
- When joining tables, ensure all join columns have the same data type.
- Analyze the database and the table schema provided as parameters and understand the relations (column and table relations).
- Always use SAFE_CAST. If performing a SAFE_CAST, use only BigQuery-supported data types.
- Always SAFE_CAST before applying aggregate functions.
- Don't include any comments in code.
- Remove ```sql and ``` from the output and generate the SQL in single line.
- Tables should be referred to using a fully qualified name enclosed in backticks (`), e.g. `project_id.owner.table_name`.
- Use all the non-aggregated columns from the "SELECT" statement when framing the "GROUP BY" block.
- Return syntactically and semantically correct SQL for BigQuery with proper relation mapping, i.e. project_id, owner, table and column relation.
- Use ONLY the column names (column_name) mentioned in Table Schema. DO NOT USE any other column names outside of this.
- Associate column_name mentioned in Table Schema only to the table_name specified under Table Schema.
- Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed.
- Table names are case sensitive. DO NOT uppercase or lowercase the table names.
- Always enclose subqueries and union queries in brackets.
</Guidelines>

Your output should be in this format. 
{format_instructions}
"""

HUMAN_PROMPT3: str = """
<User's question>
{rewrite_question}
</User's question>

<Error Message>
{error_message}
</Error Message>

<Previous Query>
{query}
</Previous Query>

<Examples>
{related_query}
</Examples>

<Table Schema>
table : `cbm-cgs-uiim-prd.mro_demo.mrodata`
</Table Schema>

<Columns Schema>
{related_columns}
</Columns Schema>
"""

PROMPT_TEMPLATE3: list[tuple] = [
    ('system', SYSTEM_PROMPT3),
    ('human', HUMAN_PROMPT3)
]

In [None]:
def fix_sql_agent_node(state: AgentState) -> AgentState:
    """Fix SQL by an agent"""
    print("\033[92m--- Fix SQL ---\033[00m")
    print(state["error_message"])
    chain = init_agent_model(prompt=PROMPT_TEMPLATE3, model_config=MODEL_CONFIG, parser=FixSQLParser)
    
    response = chain.invoke({
        "rewrite_question": state["rewrite_question"],
        "related_query": state["related_query"],
        "related_columns": state["related_columns"],
        "error_message": state["error_message"],
        "query": state["query"]
    })
    
    # Set fixed query (fixed sql query) state to be the response from an agent
    return {
        "query": response["query"]
    }

#### <mark>**8. Response agent node**</mark>

This node is an agent that presents the query result to the user in natural language.

In [269]:
SYSTEM_PROMPT4: str = """
# [ROLE] 
You are a Data Assistant that helps answer users' questions about the data in their databases.
Provide a natural sounding response to the user question using only the data provided to you.

# [PERSONA]
**Identity**: CiMie, warm yet professional female AI assistant
**Reply language:** Th
**Format:** Markdown
**Task:** Answer questions based on the provided context.
**Tone mix:** Conversational / Professional / Friendly / Supportive

# [RULES]
- Answer only based on the provided context.
- If you give links, format them as Markdown: [filename](url).
- Be concise and to the point.
- Use bullet points or numbered lists for clarity when needed.
- Use markdown formatting for better readability.
- Always be polite and respectful.
- Do *NOT* tell the user about sql query.
"""

HUMAN_PROMPT4: str = """
The user has provided the following question in natural language: 
{rewrite_question}

<Column Detail>
{related_columns}
</Column Detail>

The system has returned the following result after running the SQL query: 
{result_df}
"""

PROMPT_TEMPLATE4: list[tuple] = [
    ('system', SYSTEM_PROMPT4),
    ('human', HUMAN_PROMPT4)
]

In [None]:
def response_agent_node(state: AgentState) -> AgentState:
    """Response SQL result by an agent in natural language"""
    start_time = state.get("text2sql_time", None)
    print("\033[92m--- CiMie Response ---\033[00m")
    
    chain = init_agent_model(prompt=PROMPT_TEMPLATE4, model_config=MODEL_CONFIG, parser=None)
    
    response = chain.invoke({
        "rewrite_question": state["rewrite_question"],
        "related_columns": state["related_columns"],
        "result_df": state["result_df"].to_markdown(index=False),
    })
    formatted_history = state.get('formatted_history', [])
    formatted_history.append(
        [
            ('rewrite_question', state['rewrite_question']),
            ('related_columns', state['related_columns']),
            ('result_df', state['result_df'].to_markdown(index=False)),
            ('result_response', response)
        ]
    )
    
    print("\n" + response)
    
    if start_time:
        elapsed_time = time.time() - start_time
        print(f"\n\033[96m ===========> ⏱️ Text2sql operation time taken: {elapsed_time:.3f} seconds <===========\n\033[00m")
    else:
        print("Start time not found")
        
        
    return {
        "result_response" : response,
        "formatted_history" : formatted_history
    }

#### <mark>**9. Analyze chart type agent node**</mark>

This node is an agent that analyzes chart type based on the query result for visualization

In [271]:
class AnalyzeChartType(BaseModel):
    chart_suggestions: list[str] = Field(description="Chart type suggestions from ai analysis")

In [272]:
SYSTEM_PROMPT5: str = """
You are expert in generating visualizations.

<Best Practices>
Some commonly used charts and when to use them:
- Text or Score card is best for showing single value answer
- Table is best for Showing data in a tabular format.
- Bullet Chart is best for Showing individual values across categories.
- Bar Chart is best for Comparing individual values across categories, especially with many categories or long labels.
- Column Chart is best for Comparing individual values across categories, best for smaller datasets.
- Line Chart is best for Showing trends over time or continuous data sets with many data points.
- Area Chart is best for Emphasizing cumulative totals over time, or the magnitude of change across multiple categories.
- Pie Chart is best for Show proportions of a whole, but only for a few categories (ideally less than 6).
- Scatter Plot is best for Investigating relationships or correlations between two variables.
- Bubble Chart is best for Comparing and showing relationships between three variables.
- Histogram is best for Displaying the distribution and frequency of continuous data.
- Map Chart is best for Visualizing data with a geographic dimension (countries, states, regions, etc.).
- Gantt Chart is best for Managing projects, visualizing timelines, and task dependencies.
- Heatmap is best for Showing the density of data points across two dimensions, highlighting areas of concentration.
</Best Practices>

<Guidelines>
-Do not add any explanation to the response. Only stick to format Chart-1, Chart-2
-Do not enclose the response with js or javascript or ```
</Guidelines>

Your output should be in this format (a valid JSON format with two elements chart_1 and chart_2 as below). 
{format_instructions}
"""

HUMAN_PROMPT5: str = """
Below are the question and the result of the generated SQL. Suggest the two best chart types.
Question : {rewrite_question}
SQL result : {result_df}
"""

PROMPT_TEMPLATE5: list[tuple] = [
    ('system', SYSTEM_PROMPT5),
    ('human', HUMAN_PROMPT5)
]

In [None]:
def analyze_chart_type_agent_node(state: AgentState) -> AgentState:
    """Analyze chart types from SQL result by an agent"""
    state["visualization_time"] = time.time()
    print("\033[92m--- Analyze Chart Types ---\033[00m")
    
    chain = init_agent_model(prompt=PROMPT_TEMPLATE5, model_config=MODEL_CONFIG, parser=AnalyzeChartType)
    
    response = chain.invoke({
        "rewrite_question": state["rewrite_question"],
        "result_df": state["result_df"].to_markdown(index=False),
    })
    
    state["chart_suggestions"] = response["chart_suggestions"]
    return state
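
The visualization step indexes `chart_suggestions` with the chart counter, and the `check_count_chart` edge loops until two charts exist. If the model returned fewer than two suggestions, that lookup would raise an `IndexError`. A small normalization step, sketched below, could trim or pad the list first. `normalize_suggestions` and the `"Table"` fallback are assumptions for illustration, not part of the notebook.

```python
DEFAULT_CHART = "Table"  # hypothetical fallback chart type


def normalize_suggestions(suggestions: list[str], wanted: int = 2) -> list[str]:
    """Trim or pad the model's suggestions so exactly `wanted` entries remain."""
    # Drop empty or whitespace-only entries, then keep at most `wanted`.
    cleaned = [s.strip() for s in suggestions if s and s.strip()]
    cleaned = cleaned[:wanted]
    while len(cleaned) < wanted:
        cleaned.append(DEFAULT_CHART)
    return cleaned


print(normalize_suggestions(["Column Chart", "Bar Chart", "Pie Chart"]))
print(normalize_suggestions(["Line Chart"]))
```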

#### <mark>**10. Visualize chart type agent node**</mark>

This node is an agent that generates charts based on the result from "analyze_chart_type_agent_node" for visualization

In [274]:
class VisualizeChart(BaseModel):
    chart_code: str = Field(description="Javascript chart code")

In [275]:
SYSTEM_PROMPT6: str = """
You are expert in generating visualizations.
                
Guidelines:
-Do not add any explanation to the response.
-Do not enclose the response with js or javascript or ```

You are asked to generate a visualization for the following question:
{rewrite_question}

The SQL generated for the question is:
{query}

The results of the SQL, which should be used to generate the visualization, are given as a Markdown table:
{result_df}

Needed chart type is  : {chart_suggestions}

Guidelines:

- Generate js code for {chart_suggestions} for the visualization using google charts and its possible data column. You do not need to use all the columns if not possible.
- The generated js code should be able to be just evaluated as javascript so do not add any extra text to it.
- ONLY USE the template below and STRICTLY USE ELEMENT ID chart-{chart_div} TO CREATE THE CHART
- The drawChart function name must STRICTLY follow this name: drawChart{chart_div}

google.charts.load('current', <add packages>);
google.charts.setOnLoadCallback(drawChart);
drawchart function 
    var data = <Datatable>
    with options
Title=<<Give appropriate title>>
width=600,
height=300,
hAxis.textStyle.fontSize=5
vAxis.textStyle.fontSize=5
legend.textStyle.fontSize=10

other necessary options for the chart type
    var chart = new google.visualization.<chart name>(document.getElementById('chart-{chart_div}'));
    chart.draw()

Your output should be in this format (a valid JSON format with two chart codes as below). 
{format_instructions}
"""

HUMAN_PROMPT6: str = """
Example Response: 

google.charts.load('current', {{packages: ['corechart']}});
google.charts.setOnLoadCallback(drawChart);
    function drawChart() 
{{var data = google.visualization.arrayToDataTable([['Product SKU', 'Total Ordered Items'],
    ['GGOEGOAQ012899', 456],   ['GGOEGDHC074099', 334], 
    ['GGOEGOCB017499', 319],    ['GGOEGOCC077999', 290], 
    ['GGOEGFYQ016599', 253],  ]); 
    
var options =
    {{ title: 'Top 5 Product SKUs Ordered',  
    width: 600,   height: 300,    hAxis: {{     
    textStyle: {{       fontSize: 12    }} }},  
    vAxis: {{     textStyle: {{      fontSize: 12     }}    }},
    legend: {{    textStyle: {{       fontSize: 12      }}   }},  
    bar: {{      groupWidth: '50%'    }}  }};
    var chart = new google.visualization.BarChart(document.getElementById('chart-{chart_div}')); 
    chart.draw(data, options);}}
"""

PROMPT_TEMPLATE6: list[tuple] = [
    ("system", SYSTEM_PROMPT6),
    ("human", HUMAN_PROMPT6)
]

In [None]:
def visualize_chart_agent_node(state: AgentState) -> AgentState:
    """Generate chart code for the current chart suggestion by an agent"""
    print(f"\033[92m--- Visualize Chart {state['chart_couter'] + 1} ---\033[00m")
    
    chain = init_agent_model(prompt=PROMPT_TEMPLATE6, model_config=MODEL_CONFIG, parser=VisualizeChart)
    
    response = chain.invoke({
        "rewrite_question": state["rewrite_question"],
        "query": state["query"],
        "result_df": state["result_df"].to_markdown(index=False),
        "chart_suggestions": str(state["chart_suggestions"][state["chart_couter"]]),
        "chart_div": state["chart_couter"] + 1
    })
    
    # update states
    state["chart_js_codes"].append(response["chart_code"])
    state["chart_couter"] += 1
    
    return state
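
The prompt asks the model to target the element `chart-{chart_div}` and to name the function `drawChart{chart_div}`, but nothing verifies that the generated code does so. Since both charts end up in the same page, a name clash or a wrong element id would silently break one of them. The sketch below shows one way such a check might look; `validate_chart_code` is a hypothetical helper and only inspects the text, it does not execute any JavaScript.

```python
import re


def validate_chart_code(code: str, chart_div: int) -> list[str]:
    """Return a list of problems found in the generated chart code (empty if none)."""
    problems: list[str] = []
    # Accept either quote style around the element id.
    if f"getElementById('chart-{chart_div}')" not in code.replace('"', "'"):
        problems.append(f"missing element id chart-{chart_div}")
    if not re.search(rf"function\s+drawChart{chart_div}\s*\(", code):
        problems.append(f"missing function drawChart{chart_div}")
    return problems


good = (
    "google.charts.setOnLoadCallback(drawChart1); "
    "function drawChart1() { var chart = new google.visualization.BarChart("
    "document.getElementById('chart-1')); }"
)
print(validate_chart_code(good, 1))
print(validate_chart_code(good, 2))
```

A non-empty result could be fed back to the model as an error message, in the same spirit as the SQL fix loop.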

#### <mark>**11. Error SQL**</mark>

This node is an agent that generates an error message. It handles requests that are out of context or that are not SELECT queries.

In [277]:
SYSTEM_PROMPT7: str = """
# [ROLE] 
You are a warm yet professional female AI assistant.
Provide a natural sounding response to the user question.

# [PERSONA]
**Identity**: CiMie, warm yet professional female AI assistant
**Reply language:** Th
**Format:** Markdown
**Task:** Tell user that this sql method is not allowed.
"""

In [None]:
def error_response_agent_node(state: AgentState) -> AgentState:
    """Response Error response by an agent in natural language"""
    print("\033[92m--- Error Response ---\033[00m")
    
    chain = init_agent_model(prompt=SYSTEM_PROMPT7, model_config=MODEL_CONFIG, parser=None)
    
    response = chain.invoke({
        "input": 'The SQL system has returned "not_related_msg", which means this SQL method is not allowed. Please tell me that this method is not allowed.'
    })
    
    print("\n" + response)
    
    return {
        "result_response" : response,
    }

### Edges

In [279]:
def check_maximum_debug(state: AgentState) -> str:
    """Route after CHECK SQL: returns "NOT RELATED", "ERROR", or "OK"."""
    if state.get("error_message"):
        if "not_related_msg" in state.get("error_message"):
            return "NOT RELATED"
        if state.get("debugging_count") > MAXIMUM_DEBUG:
            raise Exception("MAXIMUM_DEBUG exceeded")
        return "ERROR"
    return "OK"
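
Together with the `FIX SQL` to `CHECK SQL` edge in the graph, this routing function forms a bounded repair loop: each failed check increments the counter, and the run is aborted once the counter exceeds the limit. The same control flow written as a plain loop is sketched below. `repair_loop`, `fake_check`, `fake_fix`, and the value of `MAX_DEBUG` are stand-ins invented for the example; in the notebook the check is a BigQuery call and the fix is an LLM call.

```python
from typing import Callable, Optional

MAX_DEBUG = 3  # stand-in for MAXIMUM_DEBUG from utils.const


def repair_loop(
    query: str,
    check: Callable[[str], Optional[str]],
    fix: Callable[[str, str], str],
    max_debug: int = MAX_DEBUG,
) -> str:
    """Check a query, asking `fix` for a new one until it passes or the budget runs out."""
    attempts = 0
    while True:
        error = check(query)
        if error is None:
            return query
        attempts += 1
        if attempts > max_debug:
            raise RuntimeError(f"gave up after {max_debug} fix attempts: {error}")
        query = fix(query, error)


def fake_check(query: str) -> Optional[str]:
    # Pretend the only valid column name is "Datess".
    return None if "Datess" in query else "Unrecognized name: Dates; Did you mean Datess?"


def fake_fix(query: str, error: str) -> str:
    return query.replace("Dates", "Datess")


print(repair_loop("SELECT Files FROM t ORDER BY Dates", fake_check, fake_fix))
```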

In [280]:
def check_count_chart(state: AgentState) -> str:
    if state["chart_couter"] < 2:
        return "GENARATE MORE CHART"
    return "FINISH"

### **Graph**

In [281]:
from langgraph.graph import StateGraph, START, END

**Flow 1**: generate analyzer and visualizer before CiMie response

In [None]:
# graph = StateGraph(AgentState)

# graph.add_node("SETUP", setup)
# graph.add_node("REWRITE QUESTION", rewrite_question_agent_node)
# graph.add_node("SEARCH RELATED QUERY", search_query_node)
# graph.add_node("SEARCH DATABASE COLUMN", search_column_description_node)
# graph.add_node("WRITE SQL", write_sql_agent_node)
# graph.add_node("CHECK SQL", check_sql_node)
# graph.add_node("FIX SQL", fix_sql_agent_node)
# graph.add_node("RUN SQL", run_sql_node)
# graph.add_node("ANALYZE CHART TYPE", analyze_chart_type_agent_node)
# graph.add_node("VISUALIZE CHART", visualize_chart_agent_node)
# graph.add_node("CIMIE RESPONSE", response_agent_node)
# graph.add_node("ERROR RESPONSE", error_response_agent_node)

# graph.add_edge(START, "SETUP")
# graph.add_edge("SETUP", "REWRITE QUESTION")
# graph.add_edge("REWRITE QUESTION", "SEARCH RELATED QUERY")
# graph.add_edge("SEARCH RELATED QUERY", "SEARCH DATABASE COLUMN")
# graph.add_edge("SEARCH DATABASE COLUMN", "WRITE SQL")
# graph.add_edge("WRITE SQL", "CHECK SQL")

# graph.add_conditional_edges(
#     "CHECK SQL",
#     check_maximum_debug,
#     {
#         "NOT RELATED": "ERROR RESPONSE",
#         "ERROR" : "FIX SQL",
#         "OK" : "RUN SQL"
#     }
# )

# graph.add_edge("FIX SQL", "CHECK SQL")
# graph.add_edge("RUN SQL", "ANALYZE CHART TYPE")
# # graph.add_edge("ANALYZE CHART TYPE", "CIMIE RESPONSE")
# graph.add_edge("ANALYZE CHART TYPE", "VISUALIZE CHART")

# graph.add_conditional_edges(
#     "VISUALIZE CHART",
#     check_count_chart,
#     {
#         "GENARATE MORE CHART": "VISUALIZE CHART",
#         "FINISH": "CIMIE RESPONSE"
#     }
# )

# graph.add_edge("CIMIE RESPONSE", END)
# graph.add_edge("ERROR RESPONSE", END)

# app = graph.compile()

**Flow 2**: generate CiMie response before analyzer and visualizer 

In [298]:
graph = StateGraph(AgentState)

graph.add_node("SETUP", setup)
graph.add_node("REWRITE QUESTION", rewrite_question_agent_node)
graph.add_node("SEARCH RELATED QUERY", search_query_node)
graph.add_node("SEARCH DATABASE COLUMN", search_column_description_node)
graph.add_node("WRITE SQL", write_sql_agent_node)
graph.add_node("CHECK SQL", check_sql_node)
graph.add_node("FIX SQL", fix_sql_agent_node)
graph.add_node("RUN SQL", run_sql_node)
graph.add_node("ANALYZE CHART TYPE", analyze_chart_type_agent_node)
graph.add_node("VISUALIZE CHART", visualize_chart_agent_node)
graph.add_node("CIMIE RESPONSE", response_agent_node)
graph.add_node("ERROR RESPONSE", error_response_agent_node)

graph.add_edge(START, "SETUP")
graph.add_edge("SETUP", "REWRITE QUESTION")
graph.add_edge("REWRITE QUESTION", "SEARCH RELATED QUERY")
graph.add_edge("SEARCH RELATED QUERY", "SEARCH DATABASE COLUMN")
graph.add_edge("SEARCH DATABASE COLUMN", "WRITE SQL")
graph.add_edge("WRITE SQL", "CHECK SQL")

graph.add_conditional_edges(
    "CHECK SQL",
    check_maximum_debug,
    {
        "NOT RELATED": "ERROR RESPONSE",
        "ERROR" : "FIX SQL",
        "OK" : "RUN SQL"
    }
)

graph.add_edge("FIX SQL", "CHECK SQL")
graph.add_edge("RUN SQL", "CIMIE RESPONSE")
graph.add_edge("CIMIE RESPONSE", "ANALYZE CHART TYPE")

graph.add_edge("ANALYZE CHART TYPE", "VISUALIZE CHART")

graph.add_conditional_edges(
    "VISUALIZE CHART",
    check_count_chart,
    {
        "GENARATE MORE CHART": "VISUALIZE CHART",
        "FINISH": END
    }
)

graph.add_edge("ERROR RESPONSE", END)

app = graph.compile()

Graph Visualization

In [284]:
# from IPython.display import Image, display
# flow = app.get_graph().draw_mermaid_png()
# display(Image(flow))

# Save image
# with open("../flow2.png", "wb") as f:
#     f.write(flow)

#### **Testing System**

Positive Case

In [296]:
result = app.invoke({'original_question' : 'Give me last 3 reports'})

chart_suggestions = result.get("chart_suggestions")
chart_js_codes = result.get("chart_js_codes")

charts: dict[str, str] = {}
if chart_suggestions and chart_js_codes:
    print(f"\nChart Type Suggesstion: {chart_suggestions}")
    for idx, chart_code in enumerate(chart_js_codes):
        charts[f"chart-{idx+1}"] = str(chart_code)

--- Rewrite question ---
--- Search query ---
--- Search column description ---
--- Write SQL ---
SELECT Files FROM `cbm-cgs-uiim-prd.mro_demo.mrodata` ORDER BY SAFE_CAST(PARSE_DATE('%d-%b-%y', Dates) AS DATE) DESC LIMIT 3
--- Check SQL ---
--- Fix SQL ---
400 Unrecognized name: Dates; Did you mean Datess? at [1:97]; reason: invalidQuery, location: query, message: Unrecognized name: Dates; Did you mean Datess? at [1:97]

Location: asia-southeast1
Job ID: 00faf13d-9a64-4087-a739-d8984215921d

--- Check SQL ---
--- Run SQL ---
--- Analyze Chart Type ---
--- Visualize Chart 1 ---
--- Visualize Chart 2 ---
--- CiMie Response ---

Hello! CiMie has gathered the 3 latest reports for you:

*   [86449.pdf](https://storage.googleapis.com/cbm-cgs-acb-km-assets/MRO-demo/86449.pdf)
*   [86396.pdf](https://storage.googleapis.com/cbm-cgs-acb-km-assets/MRO-demo/863

In [293]:
result = app.invoke({'original_question' : 'Give me last 3 reports'})

chart_suggestions = result.get("chart_suggestions")
chart_js_codes = result.get("chart_js_codes")

charts: dict[str, str] = {}
if chart_suggestions and chart_js_codes:
    print(f"\nChart Type Suggesstion: {chart_suggestions}")
    for idx, chart_code in enumerate(chart_js_codes):
        charts[f"chart-{idx+1}"] = str(chart_code)
        
if result.get("visualization_time"):
    elapsed_time = time.time() - result["visualization_time"]
    print(f"\n\033[96m ===========> ⏱️ Visualization operation time taken: {elapsed_time:.3f} seconds <===========\033[00m")

--- Rewrite question ---
--- Search query ---
--- Search column description ---
--- Write SQL ---
SELECT Files FROM `cbm-cgs-uiim-prd.mro_demo.mrodata` ORDER BY SAFE_CAST(PARSE_DATE('%d-%b-%y', Dates) AS DATE) DESC LIMIT 3
--- Check SQL ---
--- Fix SQL ---
400 Unrecognized name: Dates; Did you mean Datess? at [1:97]; reason: invalidQuery, location: query, message: Unrecognized name: Dates; Did you mean Datess? at [1:97]

Location: asia-southeast1
Job ID: 6330e49b-9427-4898-ad0f-161f087174a2

--- Check SQL ---
--- Run SQL ---
--- CiMie Response ---

Hello! CiMie has gathered the 3 latest reports for you:

*   [86449.pdf](https://storage.googleapis.com/cbm-cgs-acb-km-assets/MRO-demo/86449.pdf)
*   [86396.pdf](https://storage.googleapis.com/cbm-cgs-acb-km-assets/MRO-demo/86396.pdf)
*   [86397.pdf](https://storage.googleapis.com/cbm-cgs-acb-km-assets/MRO-demo/86397.pdf)

I hope this will be

In [300]:
result = app.invoke({'original_question' : 'Which are the top 5 machines with highest number of reports?'})

chart_suggestions = result.get("chart_suggestions")
chart_js_codes = result.get("chart_js_codes")

charts: dict[str, str] = {}
if chart_suggestions and chart_js_codes:
    print(f"\nChart Type Suggesstion: {chart_suggestions}")
    for idx, chart_code in enumerate(chart_js_codes):
        charts[f"chart-{idx+1}"] = str(chart_code)
        
if result.get("visualization_time"):
    elapsed_time = time.time() - result["visualization_time"]
    print(f"\n\033[96m ===========> ⏱️ Visualization operation time taken: {elapsed_time:.3f} seconds <===========\033[00m")

--- Rewrite question ---
--- Search query ---
--- Search column description ---
--- Write SQL ---
SELECT SAFE_CAST(t1.MachineCode AS STRING) AS MachineCode, COUNT(SAFE_CAST(t1.MachineCode AS STRING)) AS report_count FROM `cbm-cgs-uiim-prd.mro_demo.mrodata` AS t1 GROUP BY SAFE_CAST(t1.MachineCode AS STRING) ORDER BY report_count DESC LIMIT 5
--- Check SQL ---
--- Run SQL ---
--- CiMie Response ---

Hello! Based on the data received, CiMie found that the top 5 machines with the highest number of reports are as follows:

*   **W2B11**: 1 report
*   **W2A51**: 1 report
*   **W2A53**: 1 report
*   **W2A11**: 1 report
*   **W2C11**: 1 report

If you have any other questions, feel free to ask CiMie. Happy to help!

--- Analyze Chart Type ---
--- Visualize Chart 1 ---
--- Visualize Chart 2 ---

Chart Type Suggesstion: ['Column Chart', 'Bar Chart']



### **Visualization**

In [301]:
html_content = """
<!DOCTYPE html>
<html>
  <head>
    <title>Visualization Example</title>
    <script type="text/javascript" src="https://www.gstatic.com/charts/loader.js"></script>
    <style>
      body {
        padding: 0;
        margin: 0;
        font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI",
          Roboto, Oxygen, Ubuntu, Cantarell, "Open Sans", "Helvetica Neue",
          sans-serif;
      }
      .container {
        width: 100%;
        height: 100vh;
        display: flex;
        justify-content: center;
        align-items: center;
      }
      .card {
        width: fit-content;
        height: fit-content;
        text-align: center;
        background: #282828;
        color: aliceblue;
        font-size: 30px;
        padding: 2rem;
        border-radius: 1rem;
      }
    </style>
  </head>
  <body>
    <div class="container">
      <div class="card">
        <h2>Google Charts Demo</h2>
        <div id="chart-1"></div>
        <div id="chart-2"></div>
      </div>
    </div>
    <script type="text/javascript">
"""
# Append each generated chart script, flattened to a single line
for code in charts.values():
    code_single_line = code.replace("\n", " ")
    html_content += f"{code_single_line}\n"

html_content += """
    </script>
  </body>
</html>
"""

with open("../index2.html", "w", encoding="utf-8") as f:
    f.write(html_content)

print("HTML file generated: index2.html")


HTML file generated: index2.html
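
The cell above hard-codes two chart `<div>` elements and concatenates the generated scripts into the page. A variant that derives the `<div>` ids from the keys of `charts` and guards against generated code closing the script element early could look like the sketch below. `build_chart_page` is a hypothetical helper written for this illustration and omits the styling used above.

```python
def build_chart_page(charts: dict[str, str]) -> str:
    """Build a minimal HTML page with one <div> per chart and the chart scripts inlined."""
    divs = "\n".join(f'    <div id="{chart_id}"></div>' for chart_id in charts)
    # Escape "</script" so generated code cannot end the script element early.
    scripts = "\n".join(code.replace("</script", "<\\/script") for code in charts.values())
    return (
        "<!DOCTYPE html>\n<html>\n  <head>\n"
        '    <script type="text/javascript" src="https://www.gstatic.com/charts/loader.js"></script>\n'
        "  </head>\n  <body>\n"
        f"{divs}\n"
        '    <script type="text/javascript">\n'
        f"{scripts}\n"
        "    </script>\n  </body>\n</html>\n"
    )


page = build_chart_page({"chart-1": "var a = 1;", "chart-2": "var b = '</script>';"})
print(page)
```

Generating the `<div>` elements from the dictionary keeps the page in sync with the number of charts if the chart limit in `check_count_chart` ever changes.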


Negative Case

In [None]:
result = app.invoke({'original_question' : 'Can you ADD me the new report file of machine (W12C4) 1,000 records?'})
# result = app.invoke({'original_question' : "INSERT INTO employees (id, first_name, last_name, email, hire_date, salary) VALUES (1, 'Alice', 'Johnson', 'alice.johnson@example.com', '2025-01-15', 55000);"})

--- Rewrite question ---
--- Search query ---
--- Search column description ---
--- Write SQL ---
not_related_msg
--- Check SQL ---
--- Error Response ---

Hello, CiMie is happy to assist you.

CiMie must inform you that using SQL commands directly is not permitted in our system, for reasons of data security and to maintain overall system stability.

If you have any questions or need help with any data-related task, please share more details with CiMie. CiMie is happy to advise and help through the appropriate channels.

Thank you for your understanding.
