# Importing Libraries


In [2]:
import os
import psycopg2
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
import pandas as pd
from langchain_community.utilities.sql_database import SQLDatabase
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from dotenv import load_dotenv

## Test the Connection

In [4]:
import psycopg2

# Connection settings come from the environment. Load a local .env file first so a
# fresh kernel picks them up; variables that are already set are not overridden.
load_dotenv()

host = os.getenv("HOST")          # Address of the Postgres / Cloud SQL instance
port = os.getenv("PORT")          # Postgres port (5432 by default)
database = os.getenv("DATABASE")  # Database name
user = os.getenv("USER")          # Database username
password = os.getenv("PASSWORD")  # Database user's password

# sslmode='require' forces an encrypted connection.
conn = None
try:
    conn = psycopg2.connect(
        host=host,
        port=port,
        dbname=database,
        user=user,
        password=password,
        sslmode='require'
    )
    print("Connection successful!")
    # Do your SQL work here...
except Exception as e:
    print("Connection failed:", e)
finally:
    # Always release the connection, even if the SQL work above raised.
    if conn is not None:
        conn.close()


Connection successful!


## Create the schema in the database

In [1]:
from sqlalchemy import create_engine
from schema import Base, get_tables_in_order, create_rls_policy  # import from your schema.py file

import os

# The connection URL (which embeds the password) is read from the environment so the
# secret never lands in the notebook or its version history.
# Expected form: postgresql+psycopg2://<user>:<password>@<host>:<port>/<dbname>
DATABASE_URL = os.environ.get("DATABASE_URL")
if not DATABASE_URL:
    raise RuntimeError(
        "DATABASE_URL is not set; export it (or add it to your .env and load it) "
        "before running this cell."
    )

engine = create_engine(DATABASE_URL)

def create_all_tables():
    """Create every table from schema.py, then apply row-level-security policies.

    Tables are created in the order returned by ``get_tables_in_order()``;
    ``checkfirst=True`` skips tables that already exist. A policy built by
    ``create_rls_policy`` is then executed for each table that has a
    ``company_id`` column.
    """
    # NOTE(review): binding the metadata is not needed by this function (every call
    # below passes the engine explicitly); kept in case other code relies on it.
    Base.metadata.bind = engine
    tables = get_tables_in_order()  # Method imported from schema.py
    for table_class in tables:
        table_class.__table__.create(bind=engine, checkfirst=True)

    # Apply RLS policies inside an explicit transaction: engine.begin() commits on
    # success and rolls back on error. With a plain engine.connect() and no commit,
    # the statements can be discarded when the connection is closed.
    with engine.begin() as conn:
        for table_class in tables:
            if 'company_id' in table_class.__table__.c:
                conn.execute(create_rls_policy(table_class.__tablename__))

if __name__ == "__main__":
    create_all_tables()
    print("All tables created with RLS policies applied.")



KeyboardInterrupt



In [2]:
import nest_asyncio
nest_asyncio.apply()
from langgraph.prebuilt import create_react_agent
# TODO(developer): replace this with another import if needed
from langchain_google_vertexai import ChatVertexAI
# from langchain_google_genai import ChatGoogleGenerativeAI
# from langchain_anthropic import ChatAnthropic
from langgraph.checkpoint.memory import MemorySaver
from toolbox_langchain import ToolboxClient
# System-style preamble; run_application() prepends it to every query before sending.
prompt = """
  You're a SQL expert, you know how convert natural language to SQL,  
"""

# Sample questions sent one after another on the same conversation thread.
queries = ["What are the informations I can get from the tool ?",
           "what kind of data is present in the employees table ? "
]

async def run_application():
    # TODO(developer): replace this with another model if needed
    model = ChatVertexAI(model_name="gemini-2.0-flash-001", project="gen-lang-client-0571342867")
    # model = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001")
    # model = ChatAnthropic(model="claude-3-5-sonnet-20240620")

    # Load the tools from the Toolbox server
    async with ToolboxClient("http://127.0.0.1:5000") as client:
        tools = await client.aload_toolset()

        agent = create_react_agent(model, tools, checkpointer=MemorySaver())

        config = {"configurable": {"thread_id": "thread-1"}}
        for query in queries:
            inputs = {"messages": [("user", prompt + query)]}
            response = agent.invoke(inputs, stream_mode="values", config=config)
            print(response["messages"][-1].content)

await run_application()


You can use this tool to:

1.  **Retrieve database schema information:**  You can get information about the tables and columns in the database.
2.  **Execute SQL queries:** You can run SQL queries against the database.

The `employees` table contains the following columns:

*   `employee_id`
*   `company_id`
*   `department_id`
*   `employee_code`
*   `name`
*   `pan`
*   `pf_number`
*   `esi_number`
*   `uan`
*   `doj` (Date of Joining)
*   `dol` (Date of Leaving)
*   `salary`
*   `email`
*   `phone`
*   `address`
*   `city`
*   `state`
*   `country`
*   `status`
*   `created_at`

This table appears to store information about employees, including their personal details, employment details, contact information, and employment status.


In [12]:
## Load the metadata on Redis

## Load the metadata on Redis

In [17]:
import json
import redis.asyncio as redis
from toolbox_langchain import ToolboxClient

async def fetch_schema_from_toolbox_and_save(redis_url, toolbox_url="http://127.0.0.1:5000", toolset_name="my-toolset"):
    # Connect to Redis using redis.asyncio
    redis_client = redis.from_url(redis_url)
    
    # Connect to Toolbox client
    async with ToolboxClient(toolbox_url) as client:
        # Load the toolset
        tools = await client.aload_toolset(toolset_name)
        
        # Find 'get-schema' tool
        get_schema_tool = next((t for t in tools if t.name == "get-schema"), None)
        print(get_schema_tool)
        if get_schema_tool is None:
            raise ValueError("get-schema tool not found in toolset")
        
        # Call get-schema tool asynchronously
        schema_output = get_schema_tool.invoke({})
        print(schema_output)
        # Serialize and store in Redis
        metadata_json = schema_output

        await redis_client.set("db_metadata", metadata_json)
    
    # Close Redis connection
    await redis_client.close()


redis_url = "redis://localhost:6379"
await fetch_schema_from_toolbox_and_save(redis_url)

name='get-schema' description='Retrieve database schema information including tables and columns.' args_schema=<class 'toolbox_core.utils.get-schema'>
[{"column_name":"account_id","table_name":"accounts"},{"column_name":"company_id","table_name":"accounts"},{"column_name":"code","table_name":"accounts"},{"column_name":"name","table_name":"accounts"},{"column_name":"type","table_name":"accounts"},{"column_name":"parent_account_id","table_name":"accounts"},{"column_name":"created_at","table_name":"accounts"},{"column_name":"is_active","table_name":"accounts"},{"column_name":"address_id","table_name":"addresses"},{"column_name":"company_id","table_name":"addresses"},{"column_name":"address_line1","table_name":"addresses"},{"column_name":"address_line2","table_name":"addresses"},{"column_name":"city","table_name":"addresses"},{"column_name":"state","table_name":"addresses"},{"column_name":"country","table_name":"addresses"},{"column_name":"pin_code","table_name":"addresses"},{"column_name"

  await redis_client.close()


## Check Redis Data 

In [19]:
import asyncio
import redis.asyncio as redis

async def check_redis_data(redis_url="redis://localhost:6379", key="db_metadata"):
    redis_client = redis.from_url(redis_url)
    value = await redis_client.get(key)
    if value:
        print(value.decode('utf-8'))  # assuming data is stored as a utf-8 JSON string
    else:
        print(f"No data found for key '{key}'")
    await redis_client.close()

await check_redis_data()


[{"column_name":"account_id","table_name":"accounts"},{"column_name":"company_id","table_name":"accounts"},{"column_name":"code","table_name":"accounts"},{"column_name":"name","table_name":"accounts"},{"column_name":"type","table_name":"accounts"},{"column_name":"parent_account_id","table_name":"accounts"},{"column_name":"created_at","table_name":"accounts"},{"column_name":"is_active","table_name":"accounts"},{"column_name":"address_id","table_name":"addresses"},{"column_name":"company_id","table_name":"addresses"},{"column_name":"address_line1","table_name":"addresses"},{"column_name":"address_line2","table_name":"addresses"},{"column_name":"city","table_name":"addresses"},{"column_name":"state","table_name":"addresses"},{"column_name":"country","table_name":"addresses"},{"column_name":"pin_code","table_name":"addresses"},{"column_name":"latitude","table_name":"addresses"},{"column_name":"longitude","table_name":"addresses"},{"column_name":"created_at","table_name":"addresses"},{"colu

  await redis_client.close()


# Agentic Implementation

## v2

In [1]:
import asyncio
import json
import sqlparse
import redis.asyncio as redis
from toolbox_langchain import ToolboxClient
from langchain_google_vertexai import ChatVertexAI
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.memory.chat_message_histories import RedisChatMessageHistory

In [2]:
from dotenv import load_dotenv
import os
# Read key/value pairs from the local .env file into the process environment.
load_dotenv('.env')
# OpenAI API key handed to ChatOpenAI in the pipeline below; None when OPEN_AI_KEY is unset.
open_ai_key = os.getenv("OPEN_AI_KEY")

In [4]:


# Persistent conversation memory in Redis.
# Module-level and shared by all three agents: each agent appends its output here and
# reads the full message list back when building its prompt.
history = RedisChatMessageHistory(
    url="redis://localhost:6379",
    session_id="user_001"
)


def extract_sql_from_response(response: str) -> str:
    """Pull the SQL statement out of an LLM reply.

    A fenced block tagged ``sql`` is preferred. If there is none, a fence tagged
    ``postgresql``/``postgres``/``psql`` or an untagged fence (opening backticks
    followed by whitespace) is accepted. Fences carrying any other language tag
    are ignored so that non-SQL code blocks are not mistaken for queries.

    Args:
        response: Raw text returned by the model.

    Returns:
        The stripped contents of the first matching fence, or the whole stripped
        response when no fence matches. Empty or ``None`` input yields ``""``.
    """
    import re

    if not response:
        return ""

    flags = re.DOTALL | re.IGNORECASE
    fence_patterns = (
        # ```sql ... ```  (checked first so existing behaviour is preserved)
        r"```sql\s*(.*?)```",
        # ```postgresql ... ``` / ```postgres ... ``` / ```psql ... ``` / bare ``` ... ```
        r"```(?:(?:postgresql|postgres|psql)\b|(?=\s))\s*(.*?)```",
    )
    for pattern in fence_patterns:
        match = re.search(pattern, response, flags)
        if match:
            return match.group(1).strip()

    # Fallback: no usable fence, so return the whole response.
    return response.strip()

# Load schema from Redis
async def load_schema(redis_url: str):
    """Load the cached schema metadata from the Redis key ``db_metadata``.

    Args:
        redis_url: URL of the Redis server holding the cached schema.

    Returns:
        The decoded JSON value, or an empty list when the key is missing or empty.
    """
    redis_client = redis.from_url(redis_url)
    try:
        schema_json = await redis_client.get("db_metadata")
    finally:
        # Release the connection even if the read failed. Prefer aclose() where the
        # installed redis-py provides it; close() is deprecated there.
        close_client = getattr(redis_client, "aclose", None) or redis_client.close
        await close_client()
    # json.loads accepts both bytes and str, so no explicit decode is needed.
    return json.loads(schema_json) if schema_json else []

# Agent 1: NL to SQL with follow-up support
# Modified Agent 1
async def agent1_nl_to_sql(user_question, schema_info, llm):
    """Agent 1: turn a user message into SQL, or answer in plain text.

    Args:
        user_question: The user's latest message.
        schema_info: Iterable of dicts with ``table_name`` and ``column_name`` keys.
        llm: Chat model exposing ``invoke(prompt)`` and returning an object with ``.content``.

    Returns:
        ``("sql", <query>)`` when the reply contained a fenced SQL block, otherwise
        ``("text", <reply>)``. The outcome is also appended to the shared ``history``.
    """
    sql_marker = "[Agent 1 Generated SQL]"
    past_messages = history.messages  # read once; reused for the prompt below

    # Most recent SQL produced by this agent, used as context for follow-ups.
    last_sql = None
    for msg in reversed(past_messages):
        if sql_marker in msg.content:
            last_sql = msg.content.replace(sql_marker, "").strip()
            break

    # Group columns by table in one pass, keeping first-seen table order.
    columns_by_table = {}
    for col in schema_info:
        columns_by_table.setdefault(col['table_name'], []).append(col['column_name'])
    schema_desc = "".join(
        f"Table {table} (Columns: {', '.join(cols)})\n"
        for table, cols in columns_by_table.items()
    )

    template_text = (
        "Conversation so far:\n{history}\n\n"
        "You are an intelligent AI agent.\n"
        # The schema must be part of the prompt, otherwise the model has to guess
        # table and column names.
        "Database schema (PostgreSQL):\n{schema}\n"
        "If the user request is related to database query → generate ONLY a SQL (Postgres) wrapped inside ```sql ... ```.\n"
        "If the user request is normal conversation (not database related) → just reply in plain text.\n"
    )
    prompt_values = {
        "history": "\n".join([f"{m.type}: {m.content}" for m in past_messages]),
        "schema": schema_desc,
        "question": user_question,
    }
    if last_sql:
        template_text += (
            "Last SQL was:\n{last_sql}\n"
            "User’s follow-up:\n{question}\n"
        )
        prompt_values["last_sql"] = last_sql
    else:
        template_text += "User’s question:\n{question}\n"

    # from_template derives the input variables from the placeholders actually present.
    prompt = PromptTemplate.from_template(template_text).format(**prompt_values)

    resp = llm.invoke(prompt)
    response_content = resp.content.strip()

    sql_candidate = extract_sql_from_response(response_content)

    # The extractor returns the whole reply when it finds no fence, so a result that
    # differs from the reply means a fenced SQL block was found.
    if sql_candidate and sql_candidate != response_content:
        history.add_ai_message(f"[Agent 1 Generated SQL] {sql_candidate}")
        print(f"[Agent 1 Output - SQL]\n{sql_candidate}")
        return ("sql", sql_candidate)
    else:
        history.add_ai_message(f"[Agent 1 Normal Reply] {response_content}")
        print(f"[Agent 1 Output - Text]\n{response_content}")
        return ("text", response_content)

# Agent 2 and 3 same as before (validate and execute)...

# Modified Agent 2
# async def agent2_validate_sql(sql_query, toolbox_url="http://127.0.0.1:5000", toolset_name="my-toolset"):
#     async with ToolboxClient(toolbox_url) as client:
#         tools = await client.aload_toolset(toolset_name)
#         validate_tool = next((t for t in tools if t.name == "validate-sql"), None)
#         if not validate_tool:
#             raise RuntimeError("validate-sql tool not found in tools.yaml")

#         tool_call = {"query": sql_query}
#         result = validate_tool.invoke(tool_call)

#         history.add_ai_message(f"[Agent 2 Validation Result] {result}")
#         print("[Agent 2] Validation Result:", result)
#         return sql_query if "valid" in str(result).lower() else None


# 

async def agent2_validate_sql(sql_query, schema_info, user_question, llm, max_retries=2):
    """Agent 2: ask an LLM to validate a SQL query and, if needed, rewrite it.

    Up to ``max_retries + 1`` validation rounds are run. Each round asks the model
    to answer ``VALID: <sql>`` or ``INVALID: <reason>``; on ``INVALID`` a second
    call asks for a corrected query, which is validated in the next round.

    Args:
        sql_query: Candidate SQL to validate.
        schema_info: Schema metadata; serialised to JSON for the prompt.
        user_question: The original user request.
        llm: Chat model exposing ``invoke(prompt)`` and returning an object with ``.content``.
        max_retries: Number of correction rounds allowed after the first validation.

    Returns:
        The SQL echoed by the model after ``VALID:``, or the latest candidate when
        the rounds run out or the model answers in an unexpected format. In those
        two cases the returned SQL has NOT been confirmed valid.
    """
    valid_prefix = "VALID:"
    invalid_prefix = "INVALID:"
    corrected_sql = sql_query
    schema_desc = json.dumps(schema_info)

    # Templates do not depend on the loop, so build them once.
    validation_template = PromptTemplate(
        input_variables=["history", "schema", "question", "sql_query"],
        template=(
            "Conversation so far:\n{history}\n\n"
            "You are an expert SQL validator. "
            "Given the schema (PostgreSQL) and SQL query, determine if it is valid:\n\n"
            "Schema: {schema}\n"
            "SQL Query: {sql_query}\n\n"
            "User Question: {question}\n\n"
            "Respond in strict format:\n"
            "- If valid: 'VALID: <repeat SQL here>'\n"
            "- If invalid: 'INVALID: <reason>'"
        ),
    )
    correction_template = PromptTemplate(
        input_variables=["history", "schema", "question", "bad_sql", "error"],
        template=(
            "Conversation so far:\n{history}\n\n"
            "The SQL query `{bad_sql}` is invalid due to: {error}\n"
            "Schema: {schema}\n\n"
            "Rewrite the SQL so it matches the schema, "
            "answers the user request, and follows PostgreSQL syntax.\n"
            "User Question: {question}\n"
            "Corrected SQL:"
        ),
    )

    def _history_text():
        # Re-read on every call: this function appends to the history as it goes.
        return "\n".join([f"{m.type}: {m.content}" for m in history.messages])

    for attempt in range(1, max_retries + 2):
        validation_prompt = validation_template.format(
            history=_history_text(),
            schema=schema_desc,
            question=user_question,
            sql_query=corrected_sql,
        )

        validation_resp = llm.invoke(validation_prompt).content.strip()
        print(f"[Agent 2 Validation Attempt {attempt}] {validation_resp}")

        # Case 1: valid SQL. Slice off only the leading prefix; str.replace would
        # also delete any later "VALID:" that appears inside the SQL itself.
        if validation_resp.startswith(valid_prefix):
            valid_sql = validation_resp[len(valid_prefix):].strip()
            history.add_ai_message(f"[Agent 2] SQL validation passed.\n{valid_sql}")
            print("[Agent 2] SQL validation passed.")
            return valid_sql

        # Case 2: invalid SQL, so ask for a corrected query.
        elif validation_resp.startswith(invalid_prefix):
            reason = validation_resp[len(invalid_prefix):].strip()
            history.add_ai_message(f"[Agent 2] Validation failed: {reason}")
            print(f"[Agent 2] Validation failed: {reason}")

            correction_prompt = correction_template.format(
                history=_history_text(),
                schema=schema_desc,
                question=user_question,
                bad_sql=corrected_sql,
                error=reason,
            )

            # Drop any markdown fence so the next round validates bare SQL.
            corrected_sql = extract_sql_from_response(llm.invoke(correction_prompt).content)
            history.add_ai_message(f"[Agent 2 Corrected SQL] {corrected_sql}")
            print(f"[Agent 2 Corrected SQL] {corrected_sql}")

        else:
            # Unexpected LLM response: stop and return the current candidate.
            history.add_ai_message(f"[Agent 2] Unexpected validation response: {validation_resp}")
            print(f"[Agent 2] Unexpected validation response: {validation_resp}")
            break

    return corrected_sql



async def agent3_execute_sql(sql_query, toolbox_url="http://127.0.0.1:5000", toolset_name="my-toolset"):
    """Agent 3: run a SQL query through the Toolbox ``run-query`` tool.

    Args:
        sql_query: SQL text, optionally wrapped in a markdown fence.
        toolbox_url: Base URL of the Toolbox server.
        toolset_name: Name of the toolset that contains ``run-query``.

    Returns:
        Whatever the ``run-query`` tool returns; it is also appended to ``history``.

    Raises:
        RuntimeError: If the toolset has no ``run-query`` tool.
        ValueError: If no SQL text is left after stripping the fence.
    """
    async with ToolboxClient(toolbox_url) as client:
        tools = await client.aload_toolset(toolset_name)
        run_query_tool = next((t for t in tools if t.name == "run-query"), None)
        if not run_query_tool:
            raise RuntimeError("run-query tool not found")

        raw_sql = extract_sql_from_response(sql_query)
        if not raw_sql:
            # Fail clearly instead of sending an empty statement to the database.
            raise ValueError("No SQL to execute")
        print("SQL sent:", repr(raw_sql))

        print(f"The correct sql query: {raw_sql}")
        tool_call = {'query': raw_sql}
        # Await the tool so the event loop is not blocked while the query runs.
        result = await run_query_tool.ainvoke(tool_call)
        history.add_ai_message(f"[Agent 3 Execution Result] {result}")
        print("[Agent 3] Execution Result:")
        print(result)
        return result

# Orchestrator with follow-up support
async def multi_llm_pipeline():
    redis_url = "redis://localhost:6379"
    toolbox_url = "http://127.0.0.1:5000"

    # Load schema info once outside the loop
    print("Loading the schema info...")
    schema_info = await load_schema(redis_url)
    print(f"Schema info loaded. Type: {type(schema_info)}")

    # Initialize LLM agents once outside the loop
    llm_agent1 = ChatVertexAI(model_name="gemini-2.0-flash-001")
    llm_agent2 = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        api_key=open_ai_key  # or rely on env variable
    )

    print("Starting interactive session. Type 'exit' or 'quit' to stop.")

    while True:
        user_question = input("\nEnter query (or follow-up): ")
        if user_question.strip().lower() in {"exit", "quit"}:
            print("Ending session.")
            break

        history.add_user_message(user_question)

        # NL->SQL
        sql_query = await agent1_nl_to_sql(user_question, schema_info, llm_agent1)

        # Validate SQL
        validated_sql = await agent2_validate_sql(sql_query, schema_info, user_question, llm_agent2)

        # Execute SQL and get results
        results = await agent3_execute_sql(validated_sql, toolbox_url)

        print("\n[Final Results]")
        print(results)

        print("\n[Conversation History]")
        for m in history.messages:
            print(f"{m.type.capitalize()}: {m.content}")


if __name__ == "__main__":
    await multi_llm_pipeline()


Loading the schema info...


  await redis_client.close()


Schema info loaded. Type: <class 'list'>
Starting interactive session. Type 'exit' or 'quit' to stop.
[Agent 1 Output - SQL]
SELECT SUM(total_amount) AS last_year_revenue
FROM sales_orders
WHERE EXTRACT(YEAR FROM order_date) = EXTRACT(YEAR FROM CURRENT_DATE - INTERVAL '1 year');
[Agent 2 Validation Attempt 1] VALID: SELECT SUM(total_amount) AS last_year_revenue  
FROM sales_orders  
WHERE EXTRACT(YEAR FROM order_date) = EXTRACT(YEAR FROM CURRENT_DATE - INTERVAL '1 year');
[Agent 2] SQL validation passed.
SQL sent: "SELECT SUM(total_amount) AS last_year_revenue  \nFROM sales_orders  \nWHERE EXTRACT(YEAR FROM order_date) = EXTRACT(YEAR FROM CURRENT_DATE - INTERVAL '1 year');"
The correct sql query: SELECT SUM(total_amount) AS last_year_revenue  
FROM sales_orders  
WHERE EXTRACT(YEAR FROM order_date) = EXTRACT(YEAR FROM CURRENT_DATE - INTERVAL '1 year');
[Agent 3] Execution Result:
[{"last_year_revenue":159784.78}]

[Final Results]
[{"last_year_revenue":159784.78}]

[Conversation History

'gcloud' is not recognized as an internal or external command,
operable program or batch file.


In [4]:
# from google import genai
# from google.genai import types
# client = genai.Client(
# vertexai=True, project="gen-lang-client-0571342867", location="global",
# )
# # If your image is stored in Google Cloud Storage, you can use the from_uri class method to create a Part object.
# IMAGE_URI = "gs://generativeai-downloads/images/scones.jpg"
# model = "gemini-2.5-flash-lite-preview-06-17"
# response = client.models.generate_content(
# model=model,
# contents=[
#   "How's the weather in Bangalore?"
# ],
# )
# print(response.text, end="")

DefaultCredentialsError: Your default credentials were not found. To set up Application Default Credentials, see https://cloud.google.com/docs/authentication/external/set-up-adc for more information.

In [None]:
# os.environ['OPENAI_API_KEY']= "YOUR_OPENAI_API_KEY"



<!-- # Connecting the Database
 -->


In [None]:
#  # db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}",sample_rows_in_table_info=1,include_tables=['customers','orders'],custom_table_info={'customers':"customer"})
# db = SQLDatabase.from_uri("postgresql+psycopg2://<user>:<password>@<host>/dvdrental")  # credentials redacted; load them from the environment
# print(db.dialect)
# print(db.get_usable_table_names())
# print(db.table_info)


postgresql
['actor', 'address', 'category', 'city', 'country', 'customer', 'customers', 'film', 'film_actor', 'film_category', 'inventory', 'language', 'payment', 'rental', 'staff', 'store']

CREATE TABLE actor (
	actor_id SERIAL NOT NULL, 
	first_name VARCHAR(45) NOT NULL, 
	last_name VARCHAR(45) NOT NULL, 
	last_update TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, 
	CONSTRAINT actor_pkey PRIMARY KEY (actor_id)
)

/*
3 rows from actor table:
actor_id	first_name	last_name	last_update
1	Penelope	Guiness	2013-05-26 14:47:57.620000
2	Nick	Wahlberg	2013-05-26 14:47:57.620000
3	Ed	Chase	2013-05-26 14:47:57.620000
*/


CREATE TABLE address (
	address_id SERIAL NOT NULL, 
	address VARCHAR(50) NOT NULL, 
	address2 VARCHAR(50), 
	district VARCHAR(20) NOT NULL, 
	city_id SMALLINT NOT NULL, 
	postal_code VARCHAR(10), 
	phone VARCHAR(20) NOT NULL, 
	last_update TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, 
	CONSTRAINT address_pkey PRIMARY KEY (address_id), 
	CONSTRAINT fk_address_city

In [None]:
# 

In [None]:
# 

<!-- # Converting Natural Language to SQL -->

In [None]:

# llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# generate_query = create_sql_query_chain(llm, db)
# query = generate_query.invoke({"question": "How many film categories are there?"})
# print(query)


SELECT COUNT("category_id") AS total_categories
FROM category;


<!-- # Running the SQL Query to get the Output -->

In [None]:

# execute_query = QuerySQLDataBaseTool(db=db)
# execute_query.invoke(query)


'[(16,)]'

<!-- # Rephrasing Answers -->

In [None]:


# answer_prompt = PromptTemplate.from_template(
#      """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

#  Question: {question}
#  SQL Query: {query}
#  SQL Result: {result}
#  Answer: """
#  )

# rephrase_answer = answer_prompt | llm | StrOutputParser()

# chain = (
#      RunnablePassthrough.assign(query=generate_query).assign(
#          result=itemgetter("query") | execute_query
#      )
#      | rephrase_answer
#  )
# chain.invoke({"question": "How many actors have more than 2 films"})


'There are 5 actors who have appeared in more than 2 films.'

In [None]:
# 