# 1️⃣ Configuration

In [None]:
import os
import pandas as pd
import sqlite3
from sqlalchemy import create_engine
from langchain_openai import ChatOpenAI
from langchain_experimental.agents import create_csv_agent
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file (if present) into os.environ.
load_dotenv()

# ChatOpenAI (created below) reads OPENAI_API_KEY from the environment, so the
# key only needs to be present, not copied. Fail fast with a clear message when
# it is missing; assigning None into os.environ would raise an opaque TypeError.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set - add it to your environment or to a .env file."
    )

In [None]:
# Folder holding the cleaned packaging CSV extracts. Set EUREKA_CLEAN_CSV_DIR to
# point at a different location (e.g. on another machine); the default is the
# original author's local path.
CLEAN_CSV_DIR = (
    os.getenv("EUREKA_CLEAN_CSV_DIR")
    or "/Users/keshavsaraogi/Desktop/indorama/eureka-data/clean-csv"
)

finance_file = os.path.join(CLEAN_CSV_DIR, "cleaned_finance_packaging.csv")
inventory_file = os.path.join(CLEAN_CSV_DIR, "cleaned_inventory_packaging.csv")
spend_file = os.path.join(CLEAN_CSV_DIR, "cleaned_spend_packaging.csv")
sales_file = os.path.join(CLEAN_CSV_DIR, "cleaned_sales_packaging.csv")

In [None]:
# Load CSV into pandas
# Read every cleaned extract into its own DataFrame (one per dataset).
finance_df = pd.read_csv(filepath_or_buffer=finance_file)
inventory_df = pd.read_csv(filepath_or_buffer=inventory_file)
spend_df = pd.read_csv(filepath_or_buffer=spend_file)
sales_df = pd.read_csv(filepath_or_buffer=sales_file)

In [None]:
def create_in_memory_db(df, table_name):
    """Copy *df* into a fresh in-memory SQLite database and return its connection.

    The DataFrame becomes table *table_name* (replacing any table of that name),
    without the pandas index as a column. The connection is opened with
    ``check_same_thread=False`` so it may be used from threads other than the
    one that created it.
    """
    connection = sqlite3.connect(":memory:", check_same_thread=False)
    df.to_sql(name=table_name, con=connection, if_exists="replace", index=False)
    return connection

In [None]:
# One private in-memory SQLite database per dataset, each with a single table.
finance_db = create_in_memory_db(df=finance_df, table_name="finance")
inventory_db = create_in_memory_db(df=inventory_df, table_name="inventory")
spend_db = create_in_memory_db(df=spend_df, table_name="spend")
sales_db = create_in_memory_db(df=sales_df, table_name="sales")

In [None]:
from sqlalchemy.pool import StaticPool


def _engine_for(conn):
    """Return a SQLAlchemy engine that always hands out the existing SQLite *conn*.

    ``StaticPool`` keeps exactly one connection for the engine's lifetime.
    SQLAlchemy's default pool for "sqlite://" may close pooled connections
    (e.g. during pool cleanup), and closing an in-memory SQLite connection
    discards its database. *conn* is bound per call, so each engine keeps
    pointing at its own connection even if the module-level name is rebound.
    """
    return create_engine("sqlite://", creator=lambda: conn, poolclass=StaticPool)


finance_engine = _engine_for(finance_db)
inventory_engine = _engine_for(inventory_db)
spend_engine = _engine_for(spend_db)
sales_engine = _engine_for(sales_db)

In [None]:
# LangChain wrappers over each engine; 5 sample rows per table are included in
# the table info shown to the LLM.
finance_sql_db = SQLDatabase(sample_rows_in_table_info=5, engine=finance_engine)
inventory_sql_db = SQLDatabase(sample_rows_in_table_info=5, engine=inventory_engine)
spend_sql_db = SQLDatabase(sample_rows_in_table_info=5, engine=spend_engine)
sales_sql_db = SQLDatabase(sample_rows_in_table_info=5, engine=sales_engine)

In [None]:
# Shared chat model for routing, SQL generation, agents and chart classification.
llm = ChatOpenAI(model="gpt-4o", temperature=0.5)

In [None]:
# SQL tool sets (query / schema / list-tables / checker) bound to each dataset's DB.
finance_toolkit = SQLDatabaseToolkit(llm=llm, db=finance_sql_db)
inventory_toolkit = SQLDatabaseToolkit(llm=llm, db=inventory_sql_db)
spend_toolkit = SQLDatabaseToolkit(llm=llm, db=spend_sql_db)
sales_toolkit = SQLDatabaseToolkit(llm=llm, db=sales_sql_db)

In [None]:
def get_column_names(filepath):
    """Return the header of the CSV file at *filepath* as a list of column names.

    Only the first data row is parsed, so this stays cheap on large files.
    """
    header_frame = pd.read_csv(filepath, nrows=1)
    return header_frame.columns.tolist()

# Column names per dataset, fed to the dataset-routing prompt.
finance_columns = get_column_names(filepath=finance_file)
inventory_columns = get_column_names(filepath=inventory_file)
spend_columns = get_column_names(filepath=spend_file)
sales_columns = get_column_names(filepath=sales_file)

# 2️⃣ Tools & Agents

In [None]:

# Shared instruction prefix for the per-dataset SQL agents; passed as
# `prefix=` to every create_sql_agent(...) call below.
# NOTE(review): several rules overlap or pull in different directions (e.g.
# "ONLY output raw SQL text" vs. "NEVER just write SQL queries" / "ALWAYS call
# the action sql_db_query"); consider consolidating them.
# NOTE(review): the quarter hint names a column "Date"; confirm every table has
# it (the sales text-to-SQL prompt later uses "Sales Invoice Date").
# NOTE(review): keep literal "{" / "}" out of this text; LangChain may format
# or template the prefix and treat braces as variables (verify for your version).
sql_agent_prompt_prefix = """
You are a SQL expert agent following ReAct reasoning.

- You must ALWAYS output Thought -> Action -> Observation -> Final Answer.
- DO NOT output meta commentary.
- After seeing Observation results (SQL query output), ALWAYS extract concrete values.
- ALWAYS summarize the result table to answer the user's original question directly.
- NEVER say "the query successfully identifies..." — always give actual values.
- DO NOT wrap SQL code in markdown formatting or backticks.
- ONLY output valid SQL without formatting.
- If column names contain spaces, enclose them in double quotes.
- The SQL dialect is SQLite.
- ALWAYS use the available tools (sql_db_query) to execute your queries.
- NEVER just write SQL queries.
- ALWAYS call the action sql_db_query with the query as input.
- You are allowed to chain multiple queries to answer the question.
- If you encounter repeated errors or cannot execute the SQL query, still follow the ReAct format.
- When unable to answer, output:
Thought: I am unable to answer.
Final Answer: Unable to retrieve the data due to internal error.
- Do not write freeform explanations.
- Never write paragraphs describing failure.
- NEVER output markdown formatting.
- NEVER output queries inside triple backticks or code fences.
- ONLY output raw SQL text.
- SQLite does not support '%q' for quarters.
- To compute quarter, use strftime('%m', "Date") and CASE WHEN statements.
- NEVER use '%q' inside strftime() queries.
"""


In [None]:
import os
import pandas as pd
import sqlite3
from sqlalchemy import create_engine
from langchain_openai import ChatOpenAI
from langchain_experimental.agents import create_csv_agent
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

In [None]:
# One CSV (pandas) agent per dataset; master_agent uses these when a query is
# routed to "csv" (listing / non-aggregated exploration).
# SECURITY: allow_dangerous_code=True opts in to LangChain running
# LLM-generated Python inside this kernel; use only with trusted users and data.
finance_csv_agent = create_csv_agent(llm, finance_file, allow_dangerous_code=True, verbose=True)
inventory_csv_agent = create_csv_agent(llm, inventory_file, allow_dangerous_code=True, verbose=True)
spend_csv_agent = create_csv_agent(llm, spend_file, allow_dangerous_code=True, verbose=True)
sales_csv_agent = create_csv_agent(llm, sales_file, allow_dangerous_code=True, verbose=True)

In [None]:
# Rebuilds the same four toolkits created in the configuration section
# (identical db / llm pairs).
finance_toolkit = SQLDatabaseToolkit(llm=llm, db=finance_sql_db)
inventory_toolkit = SQLDatabaseToolkit(llm=llm, db=inventory_sql_db)
spend_toolkit = SQLDatabaseToolkit(llm=llm, db=spend_sql_db)
sales_toolkit = SQLDatabaseToolkit(llm=llm, db=sales_sql_db)

In [None]:
# Settings shared by every per-dataset SQL agent.
# - early_stopping_method="force": in current langchain_community,
#   create_sql_agent wraps a runnable (ReAct) agent whose stopped-response
#   handling only supports "force"; "generate" raises
#   ValueError("Got unsupported early_stopping_method ...") as soon as
#   max_iterations / max_execution_time is reached (verify for your version).
# - handle_parsing_errors is an AgentExecutor option, so it goes through
#   agent_executor_kwargs; passed as a bare keyword it is forwarded to the
#   agent object instead of the executor.
_SQL_AGENT_OPTIONS = dict(
    verbose=True,
    max_iterations=40,
    max_execution_time=120,
    early_stopping_method="force",
    agent_executor_kwargs={"handle_parsing_errors": True},
    prefix=sql_agent_prompt_prefix,
)

finance_sql_agent = create_sql_agent(llm=llm, toolkit=finance_toolkit, **_SQL_AGENT_OPTIONS)
inventory_sql_agent = create_sql_agent(llm=llm, toolkit=inventory_toolkit, **_SQL_AGENT_OPTIONS)
spend_sql_agent = create_sql_agent(llm=llm, toolkit=spend_toolkit, **_SQL_AGENT_OPTIONS)
sales_sql_agent = create_sql_agent(llm=llm, toolkit=sales_toolkit, **_SQL_AGENT_OPTIONS)

# 3️⃣ Visualization Tool

In [None]:

# 🔹 Import Matplotlib
import matplotlib.pyplot as plt

# 🔹 Import Prompt & Parser for LLM-based classification
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# 🔹 Prompt Template: asks the LLM for a single chart-type keyword given the
# user's question and the result's column names.
visualization_prompt = PromptTemplate.from_template(
    template="""
You are a visualization expert.
Given a user query and the result table columns, suggest the most appropriate visualization type.

User Query: "{query}"
Result Columns: {columns}

Available chart types: ["bar", "line", "scatter", "pie", "histogram", "none"]

Answer with ONLY the chart type (one word).
"""
)

# 🔹 LLM Chain: prompt -> chat model -> plain string
visualization_chain = (
    visualization_prompt
    | llm
    | StrOutputParser()
)

# 🔹 Classifier Function
def classify_visualization_type(user_query, df_result, chain=None):
    """Ask the LLM which chart type best fits *df_result* for *user_query*.

    Args:
        user_query: The natural-language question the result answers.
        df_result: The result DataFrame; only its column names are sent.
        chain: Runnable used for the LLM call (anything with ``invoke(dict)``
            returning a string). Defaults to ``visualization_chain``.

    Returns:
        One of "bar", "line", "scatter", "pie", "histogram" or "none". Replies
        such as "Bar." or "line chart" are normalised to the first recognised
        chart word; a reply containing none of them yields "none".
    """
    import re

    if chain is None:
        chain = visualization_chain

    allowed = ("bar", "line", "scatter", "pie", "histogram", "none")
    # str() so non-string column labels (e.g. integers) don't break the join.
    columns_str = ", ".join(map(str, df_result.columns))
    reply = chain.invoke({
        "query": user_query,
        "columns": columns_str
    })

    # Models often add punctuation, quotes, markdown or a trailing "chart".
    for word in re.findall(r"[a-z]+", reply.lower()):
        if word in allowed:
            return word
    return "none"

# 🔹 Visualization Tool Function
def visualization_tool(user_query, df_result):
    """Ask the LLM for a chart type and draw *df_result* with pandas/matplotlib.

    The first column of *df_result* is used as the x axis (or the pie labels)
    and the second column, when present, as the plotted values. When no chart
    can be drawn (the LLM answered "none" or an unknown type, or the result has
    a single column), a warning is printed and the unused figure is closed
    instead of being shown empty.

    NOTE: a later cell redefines ``visualization_tool`` (larger layout, tidier
    pie chart); once that cell has run, that definition is the one used.

    Args:
        user_query: Natural-language question the result answers; sent to the
            chart-type classifier.
        df_result: Query result to plot.
    """
    if df_result.empty:
        print("⚠️ Result is empty — nothing to plot.")
        return

    chart_type = classify_visualization_type(user_query, df_result)
    print(f"🔍 LLM-chosen chart type: {chart_type}")

    x_col = df_result.columns[0]
    y_col = df_result.columns[1] if len(df_result.columns) > 1 else None

    fig, ax = plt.subplots(figsize=(8, 6))
    plotted = True

    if chart_type in ("bar", "line", "scatter") and y_col is not None:
        df_result.plot(kind=chart_type, x=x_col, y=y_col, ax=ax)
    elif chart_type == "pie" and y_col is not None:
        # Guarded like the other branches: a single-column result has no values.
        df_result.set_index(x_col).plot(kind='pie', y=y_col, ax=ax, autopct='%1.1f%%')
    elif chart_type == "histogram" and y_col is not None:
        df_result[y_col].plot(kind='hist', ax=ax, bins=10)
    else:
        plotted = False

    if not plotted:
        print("⚠️ No suitable chart type detected or 'none' returned by LLM.")
        # Close the unused figure; inline notebook backends typically render
        # every open figure at the end of a cell, which would show empty axes.
        plt.close(fig)
        return

    ax.set_title(f"LLM: {chart_type} chart")
    plt.tight_layout()
    plt.show()


# 4️⃣ Example Controller / Example Calls

In [None]:
def route_dataset(user_query: str) -> str:
    """Ask the LLM which dataset (finance, inventory, spend, sales) answers *user_query*.

    The routing prompt receives each dataset's column names so the model can
    match the question to a schema.

    NOTE(review): ``dataset_routing_prompt`` is not defined in any cell above;
    it must exist before this runs and accept the variables query,
    finance_cols, inventory_cols, spend_cols and sales_cols.

    Returns:
        "finance", "inventory", "spend" or "sales". The first of those words in
        the reply wins, so replies like "Sales." or "the spend dataset" are
        accepted. If none appears, a warning is printed and "finance" is
        returned.
    """
    import re

    chain = dataset_routing_prompt | llm | StrOutputParser()

    prompt_input = {
        "query": user_query,
        "finance_cols": ", ".join(map(str, finance_columns)),
        "inventory_cols": ", ".join(map(str, inventory_columns)),
        "spend_cols": ", ".join(map(str, spend_columns)),
        "sales_cols": ", ".join(map(str, sales_columns)),
    }

    result = chain.invoke(prompt_input).strip().lower()

    # Tolerate punctuation / quotes / extra words around the dataset name.
    for word in re.findall(r"[a-z]+", result):
        if word in {"finance", "inventory", "spend", "sales"}:
            return word

    print(f"⚠️ Unexpected dataset output: {result} — defaulting to 'finance'")
    return "finance"

In [None]:
# Routing prompt for the second decision in master_agent: SQL agent vs. CSV agent.
agent_type_prompt = PromptTemplate.from_template(
    template="""
You are a routing assistant that decides how to process data queries.
Given the following user query:
"{query}"
Decide whether it should be handled using SQL (for aggregation, filtering, grouping, numeric analysis),
or using CSV (for visualization, listing, non-aggregated exploration).
Read dates as DD-MM-YYYY.
Respond only with: sql or csv.
"""
)

def route_agent_type(query: str, chain=None) -> str:
    """Decide whether *query* should be handled by the SQL agent or the CSV agent.

    Args:
        query: The user's natural-language question.
        chain: Runnable used for the LLM call (anything with ``invoke(dict)``
            returning a string). Defaults to
            ``agent_type_prompt | llm | StrOutputParser()``.

    Returns:
        "sql" or "csv". The first of those words found in the reply wins, so
        replies like "SQL." or "`sql`" are accepted; anything else falls back
        to "csv".
    """
    import re

    if chain is None:
        chain = agent_type_prompt | llm | StrOutputParser()

    reply = chain.invoke({"query": query})
    for word in re.findall(r"[a-z]+", reply.lower()):
        if word in ("sql", "csv"):
            return word
    return "csv"

In [None]:
def master_agent(user_query):
    """Route *user_query* to the matching dataset agent and return its result.

    Two LLM routing calls pick the dataset (finance / inventory / spend / sales)
    and the agent type (sql / csv); the selected agent's ``invoke`` result is
    returned unchanged. Spend queries handled by SQL get extra packaging
    grounding appended when ``PACKAGING_KNOWLEDGE`` is defined.
    """
    dataset = route_dataset(user_query)
    agent_type = route_agent_type(user_query)

    agents = {
        "finance": {"sql": finance_sql_agent, "csv": finance_csv_agent},
        "inventory": {"sql": inventory_sql_agent, "csv": inventory_csv_agent},
        "spend": {"sql": spend_sql_agent, "csv": spend_csv_agent},
        "sales": {"sql": sales_sql_agent, "csv": sales_csv_agent},
    }

    # Inject grounding for spend queries.
    # NOTE(review): PACKAGING_KNOWLEDGE is not defined in any cell above; look it
    # up lazily so a missing definition degrades to a warning, not a NameError.
    if dataset == "spend" and agent_type == "sql":
        grounding = globals().get("PACKAGING_KNOWLEDGE")
        if grounding:
            user_query = f"{user_query}\n\n{grounding}"
        else:
            print("⚠️ PACKAGING_KNOWLEDGE is not defined — running the spend SQL query without grounding.")

    agent = agents[dataset][agent_type]
    return agent.invoke(user_query)

# 5️⃣ Improved SQL Extractor + Example Usage

In [None]:
import matplotlib.ticker as mticker


def visualization_tool(user_query, df_result):
    """Ask the LLM for a chart type and draw *df_result* (large figure, tidy pie).

    The first column of *df_result* is the x-axis / pie-label column and the
    second column, when present, holds the plotted values. Pie charts keep only
    positive values, show at most the 15 largest slices, and put their labels
    in a legend to the right of the pie. Other charts get thousands separators
    on the y axis. When no chart can be drawn (the LLM answered "none" or an
    unknown type, the result has a single column, or a pie has no positive
    values), a warning is printed and the unused figure is closed.

    Args:
        user_query: Natural-language question the result answers; sent to the
            chart-type classifier.
        df_result: Query result to plot.
    """
    if df_result.empty:
        print("⚠️ Result is empty — nothing to plot.")
        return

    chart_type = classify_visualization_type(user_query, df_result)
    print(f"🔍 LLM-chosen chart type: {chart_type}")

    x_col = df_result.columns[0]
    y_col = df_result.columns[1] if len(df_result.columns) > 1 else None

    def _format_tick(value, _pos):
        # Thousands separators with up to 2 decimals (trailing zeros dropped).
        # int() truncation would label ticks such as 0.2, 0.4, ... all as "0".
        # round(...) + 0.0 turns tiny negative noise / -0.0 into a plain 0.
        text = f"{round(value, 2) + 0.0:,.2f}"
        return text.rstrip("0").rstrip(".")

    fig, ax = plt.subplots(figsize=(12, 8))
    plotted = True

    if chart_type in ("bar", "line", "scatter") and y_col is not None:
        df_result.plot(kind=chart_type, x=x_col, y=y_col, ax=ax)
    elif chart_type == "histogram" and y_col is not None:
        df_result[y_col].plot(kind='hist', ax=ax, bins=10)
    elif chart_type == "pie" and y_col is not None:
        # 🔥 Keep only positive values, limited to the 15 largest slices
        pie_df = df_result[df_result[y_col] > 0]
        if len(pie_df) > 15:
            pie_df = pie_df.nlargest(15, y_col)

        if pie_df.empty:
            plotted = False
        else:
            # 🔥 Plot pie WITHOUT labels on slices
            pie_df.set_index(x_col).plot(
                kind='pie',
                y=y_col,
                ax=ax,
                labels=None,
                autopct='%1.1f%%',
                legend=False
            )
            # 🔥 Place legend outside, in the same order as the slices
            ax.legend(
                labels=pie_df[x_col],
                loc='center left',
                bbox_to_anchor=(1.0, 0.5),
                title=x_col
            )
    else:
        plotted = False

    if not plotted:
        print("⚠️ No suitable chart type detected or 'none' returned by LLM.")
        # Close the unused figure; inline notebook backends typically render
        # every open figure at the end of a cell, which would show empty axes.
        plt.close(fig)
        return

    ax.set_title(f"LLM: {chart_type} chart")
    if chart_type != "pie":
        ax.get_yaxis().set_major_formatter(mticker.FuncFormatter(_format_tick))
    plt.tight_layout()
    plt.show()


In [None]:
# Column names of the sales table, used for the schema section of the prompt.
sales_columns = list(sales_df.columns)

# PromptTemplate parses "{...}" as template variables, so a column name that
# contains braces would become a bogus required input and fail at invoke time.
# Doubling them here makes the rendered prompt show single braces again.
_prompt_safe_columns = [str(col).replace("{", "{{").replace("}", "}}") for col in sales_columns]
columns_str = "\n".join([f'- "{col}"' for col in _prompt_safe_columns])

# Build the dynamic prompt. This is an f-string: {columns_str} is filled in now,
# while {{question}} becomes the {question} variable for PromptTemplate.
sql_prompt_template = f"""
You are an expert data analyst.

Here is the database schema:

Table: sales
Columns: 
{columns_str}

DO NOT invent any new tables.
Only use columns from sales table.
IMPORTANT: SQLite does NOT support strftime('%q').  
To compute quarter, use CASE WHEN on strftime('%m', "Sales Invoice Date").

Given the following user question, generate a correct SQLite SQL query using this schema.
DO NOT include any explanations.
DO NOT wrap in ```sql block.
Just output raw SQL.

Question: {{question}}

SQL Query:
"""

In [None]:
# Turn the schema-aware text into a template with a single {question} variable.
sql_prompt = PromptTemplate.from_template(template=sql_prompt_template)

In [None]:
def _strip_sql_fences(text):
    """Return the SQL inside a Markdown code fence in *text*, or *text* itself.

    The prompt asks for raw SQL, but models still sometimes reply with a
    ```sql ... ``` block (occasionally surrounded by prose), which pd.read_sql
    cannot execute. An unterminated fence is handled too. The result is
    whitespace-stripped.
    """
    import re

    match = re.search(
        r"```(?:sqlite|sql)?[ \t]*\n?(.*?)(?:```|\Z)",
        text,
        flags=re.DOTALL | re.IGNORECASE,
    )
    return (match.group(1) if match else text).strip()


# Text-to-SQL chain: prompt -> model -> plain string -> fence-free SQL.
# (LangChain coerces the plain function into a RunnableLambda.)
sql_chain = sql_prompt | llm | StrOutputParser() | _strip_sql_fences

## TESTING

In [None]:
# Generate SQL for a monthly-trend question and echo it for inspection.
sql_query = sql_chain.invoke(
    dict(question="Show me monthly sales trend for Petform (Thailand) Ltd for 2024")
)
print(sql_query)

In [None]:
# Run the generated SQL against the in-memory sales database.
df_result = pd.read_sql(sql=sql_query, con=sales_engine)

In [None]:
visualization_tool(
    user_query="Show me monthly sales trend for Petform (Thailand) Ltd for 2024",
    df_result=df_result,
)

In [None]:
# Generate SQL for per-quarter totals and echo it for inspection.
sql_query = sql_chain.invoke(
    dict(question="What is the total sales generated for Petform (Thailand) Ltd in 2024 for each quarter?")
)
print(sql_query)

In [None]:
# Run the generated SQL against the in-memory sales database.
df_result = pd.read_sql(sql=sql_query, con=sales_engine)

In [None]:
visualization_tool(
    user_query="What is the total sales generated for Petform (Thailand) Ltd in 2024 for each quarter?",
    df_result=df_result,
)

## SCATTERPLOT

In [None]:
# Scatter-plot candidate: two numeric measures against each other.
question_5 = "Show the relationship between Quantity MT and Invoice Net value."
sql_query = sql_chain.invoke(dict(question=question_5))
print(sql_query)

In [None]:
# Run the generated SQL against the in-memory sales database.
df_result = pd.read_sql(sql=sql_query, con=sales_engine)

In [None]:
visualization_tool(user_query=question_5, df_result=df_result)

## HISTOGRAM

In [None]:
# Histogram candidate: distribution of a single numeric measure.
question_4 = "Show the distribution of invoice net value for all sales in 2024 for each quarter"
sql_query = sql_chain.invoke(dict(question=question_4))
print(sql_query)

In [None]:
# Execute the SQL and plot the result in one step.
df_result = pd.read_sql(sql=sql_query, con=sales_engine)
visualization_tool(user_query=question_4, df_result=df_result)

## Pie Chart


In [None]:
# NOTE(review): sql_chain's prompt only describes the sales table; confirm it has
# an "ending balance" column, otherwise this question belongs to another dataset.
question_3 = "What is the sum of ending balance of SEVEN UP BOTTLING CO PLC for each month in 2024?"
sql_query = sql_chain.invoke(dict(question=question_3))
print(sql_query)

In [None]:
# Execute the SQL and plot the result in one step.
df_result = pd.read_sql(sql=sql_query, con=sales_engine)
visualization_tool(user_query=question_3, df_result=df_result)

## Multiple Files Queries