In [None]:
%%capture --no-stderr
%pip install -U psycopg2-binary pinecone
%pip install -U langchain_community tiktoken langchain_google_genai langchainhub langchain langgraph langchain_core python-docx  docx2txt numpy<2.0

In [None]:
# Inspect the raw vectors stored in one namespace of a Pinecone index.
#
# The `pinecone` package is installed by the setup cell at the top of the notebook.
# The legacy `pinecone-client` distribution is intentionally not installed here:
# it is the old name of the same SDK and should not be installed next to `pinecone`
# (NOTE(review): confirm against the Pinecone SDK upgrade notes).
import os

from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec

# Make values from a local .env file visible before reading the key.
load_dotenv()

# SECURITY: the API key is read from the environment instead of being hard-coded.
# Any key that was previously committed in this cell must be treated as leaked
# and rotated in the Pinecone console.
pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])

# List all indexes (kept for interactive inspection).
indexes = pc.list_indexes()

# Index to inspect; replace with your actual index name.
index_name = "table-index"
index = pc.Index(index_name)

# IDs "1" .. "55" as strings; replace with your actual vector IDs.
vector_ids = [str(i) for i in range(1, 56)]

# Fetch the vectors for those IDs from the column-details namespace.
response = index.fetch(ids=vector_ids, namespace="column-details-university")


In [None]:
# Print the stored chunk text for every requested vector ID.
for vector_id in vector_ids:
    print(f"Vector ID: {vector_id}")
    fetched = response.vectors.get(vector_id)
    if fetched is None:
        # The requested IDs are guesses ("1".."55"); an ID missing from the
        # response must not abort the whole listing with a KeyError.
        print("Fetched vectors: <no vector returned for this ID>")
    else:
        print("Fetched vectors:", fetched.to_dict()["metadata"]["chunk_text"])
    print("-------------------------------------------------------")

In [None]:
from dotenv import load_dotenv
from langgraph.graph import StateGraph, START, END
import os

# Load settings before reading them so values provided via load_dotenv() are visible.
load_dotenv()

# LLM credentials.
google_api_key = os.environ.get("GOOGLE_API_KEY")

# Pinecone credentials and index endpoint.
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
index_host = os.environ.get("PINECONE_HOST")

# PostgreSQL connection settings (each is None when the variable is unset).
host = os.environ.get("DATABASE_HOST")
port = os.environ.get("DATABASE_PORT")
user = os.environ.get("DATABASE_USER")
password = os.environ.get("DATABASE_PASSWORD")

In [3]:
from langchain_google_genai import ChatGoogleGenerativeAI
from pinecone import Pinecone

# Pinecone handle used by every retrieval node; addressed by host rather than by index name.
pc = Pinecone(api_key=pinecone_api_key)
index = pc.Index(host=index_host)

# Chat model shared by all prompt chains in this notebook.
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    google_api_key=google_api_key,
)

In [4]:
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from typing import Annotated, Dict, List, Optional, Sequence, Set
from dataclasses import dataclass, field, replace
from typing_extensions import TypedDict
import re

@dataclass
class InputState:
    """External input schema of the graph.

    Callers provide the conversation messages and the candidate table names;
    the first node reads both and builds the internal ``AgentState`` from them.
    """

    # Conversation messages; ``add_messages`` is attached as the Annotated metadata.
    messages: Annotated[Sequence[AnyMessage], add_messages] = field(default_factory=list)
    # Table names passed to the question-rewriting prompt.
    tables: List[str] = field(default_factory=list)

@dataclass
class AgentState:
    """Internal working state passed between the graph nodes.

    Every field has a default so nodes can construct or copy the state with
    only the values they care about.
    """

    # Retry budgets: extra metadata fetches and SQL regeneration attempts left.
    remaining_datafetch: int = 1
    remaining_querygen: int = 3

    # The user's question in its original, rewritten and result-oriented forms.
    question: str = ""
    rewritten_question: str = ""
    updated_question: str = ""

    # Retrieved context handed to the SQL-generation prompt.
    relevant_queries: List[str] = field(default_factory=list)
    # NOTE(review): the retrieval nodes store each hit's `fields` object here,
    # not a plain string -- confirm the element type and tighten the annotation.
    relevant_tables: List[str] = field(default_factory=list)
    relevant_columns: List[str] = field(default_factory=list)

    # Chunk IDs already returned by retrieval, used to skip duplicates.
    already_seen_chunk_column: Set[str] = field(default_factory=set)
    already_seen_chunk_table: Set[str] = field(default_factory=set)

    # Outcome of SQL generation and execution.
    error_message: str = ""
    explanation: str = ""
    is_sufficient_data: bool = False
    sql_query: str = ""
    query_executed_successfully: bool = False
    # One dict per result row, keyed by column name (built in execute_query).
    result: List[Dict[str, object]] = field(default_factory=list)

    # Chart recommendation written by generate_chart_insight. Declared here so
    # the values are real state fields and `dataclasses.replace` accepts them.
    chart_type: Optional[str] = None
    x_axis: Optional[str] = None
    y_axis: Optional[str] = None
    insight: str = ""

@dataclass
class OutputState:
    """Defines the output state for the agent, representing the expected response structure.

    This is a plain dataclass, like ``InputState`` and ``AgentState``. The
    previous definition stacked ``@dataclass`` on a ``TypedDict`` and gave the
    fields default values, which a ``TypedDict`` does not support.
    """

    # Final SQL text and the explanation of how it was built.
    sql_query: str = ""
    explanation: str = ""
    # Tables and columns referenced by the final query.
    used_tables: List[str] = field(default_factory=list)
    used_columns: List[str] = field(default_factory=list)

In [5]:
##### Question rewriting and generation format class ####
from pydantic import BaseModel, Field
from typing import List

# Structured description of one table a query plan expects to use.
# All three fields are required (`...`); the `description` strings are part of
# the schema, so they are runtime text rather than documentation.
# NOTE(review): no visible cell references this model -- confirm it is still needed.
class TableColumnPlan(BaseModel):
    # Name of the table the plan refers to.
    table_name: str = Field(
        ...,
        description="The name of the table that will be used in the SQL query."
    )
    # Free-text descriptions of the needed data, deliberately not exact column names.
    expected_columns: List[str] = Field(
        ...,
        description=(
            "A list of descriptions of the types of data expected from this table. "
            "Do not use exact column names. Instead, describe what kind of information is needed. "
            "Example: 'user ID of the employee', 'Name of the employee', 'name of the skill', 'company identifier', 'Price of the item',etc. This type of short description"
        )
    )
    # Why the table is needed for the question.
    purpose: str = Field(
        ...,
        description="A short explanation of why this table is needed in the context of answering the query."
    )

# Structured output schema for the question-rewriting step: rewrite_question
# passes this model to `llm.with_structured_output` and reads
# `rewritten_question` and `updated_question` from the result.
# The `description` strings are runtime text and must not be edited casually.
class QueryRewritePlan(BaseModel):
    # The question exactly as the user asked it.
    original_question: str = Field(
        ...,
        description="The user's original natural language question."
    )
    # Explicit, detailed restatement; used as the retrieval query text.
    rewritten_question: str = Field(
        ...,
        description=(
            "A rewritten version of the original question that clearly explains the intent, "
            "mentions relevant entities, and makes relationships between tables obvious. "
            "This should not be a SQL query, but a detailed version of the natural language question."
        )
    )
    # Result-oriented restatement; used as the question in the SQL prompts.
    updated_question: str = Field(
        ...,
        description=(
            "An enhanced version of the original question that explicitly mentions what the user "
            "expects to see in the result — including useful fields like names, IDs, totals, etc. "
            "If you think it required filtering or sorting then also mention that data on which you want to filter or sort. "
            "Keep it in plain natural language (no SQL terms), and do not mention specific table or column names."
        )
    )


In [6]:
from typing import Literal
from pydantic import BaseModel, Field

# Schema passed to `llm.with_structured_output` in generate_sql_query.
# The docstring and `description` strings below are left untouched because they
# are runtime text belonging to the schema.
class SQLQueryResponse(BaseModel):
    """
    Structured response for SQL query generation.
    """
    # SQL text as returned by the model; generate_sql_query runs it through
    # extract_sql_block before storing it.
    sql_query: str = Field(
        ...,
        description="""The generated SQL query string. the SQL query string wrapped in backticks like:
        ```sql
        SELECT * FROM ... WHERE department = 'Science';
        ```
        If insufficient data, write N/A instead of query."""
    )
    # Copied to state.explanation; also used as the search text when more metadata is fetched.
    explanation: str = Field(
        ...,
        description="Step-by-step explanation of how the query was created, or if insufficient data, explain what is missing."
    )
    # Drives the routing in is_query_generated.
    is_sufficient_data: bool = Field(
        ...,
        description="True if all required tables and columns were provided to generate a valid query, else False."
    )

In [7]:
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, AIMessage

def rewrite_question(state : InputState)-> AgentState:
    """Turn the latest user message into explicit question variants for SQL planning.

    Reads the content of the last message (empty string when there are no
    messages) and the offered table names, asks the LLM for a
    ``QueryRewritePlan`` and returns a fresh ``AgentState`` carrying the
    original, rewritten and updated questions; all other fields keep their
    defaults.
    """
    if state.messages:
        question = state.messages[-1].content
    else:
        question = ""
    tables = state.tables or []

    rewrite_prompt = PromptTemplate.from_template("""
    You are a SQL planning assistant.

    Given a user's natural language question:
    1. Rewrite it in a clear, explicit and detailed form that makes all the necessary table and column relationships obvious.
    2. Make sure the rewritten version **asks to display all useful columns**
    3. If the original question is vague (e.g., “top students”), clarify the logic (e.g., “students with the highest average grade”).
    4. Use terminology similar to column or table names when possible (e.g., "student name" instead of "who" or "person").
    5. Identify which tables are required.
    6. For each table, list the descriptions of columns you expect to use and why.
    7. A form that can help a retriever or another AI system better identify relevant tables or columns.
                
    Return your result in the required structured format.
                                        
    ## This are the tables : {table}


    Original Question: {question}
    """)

    # Constrain the model output to the QueryRewritePlan schema.
    structured_llm = llm.with_structured_output(QueryRewritePlan)
    plan = (rewrite_prompt | structured_llm).invoke({"question": question, "table": tables})

    # Only the three question fields differ from the AgentState defaults.
    return AgentState(
        question=question,
        rewritten_question=plan.rewritten_question,
        updated_question=plan.updated_question,
    )


In [8]:
def get_relevant_queries(state : AgentState)->AgentState:
    """Retrieve example question/SQL pairs similar to the rewritten question.

    Searches the ``query-example-university`` namespace for the top 7 hits and
    stores each hit's ``chunk_text`` field (empty string when absent) in
    ``state.relevant_queries``. The state is updated in place and returned.
    """
    search_request = {
        "inputs": {"text":  state.rewritten_question},
        "top_k": 7
    }
    example_search = index.search(namespace="query-example-university", query=search_request)

    state.relevant_queries = [
        hit.fields.get("chunk_text", "") for hit in example_search.result.hits
    ]
    return state
    

In [9]:
def get_table_and_columns(state : AgentState)->AgentState:
    """Initial metadata retrieval for the rewritten question.

    Runs one search against the table namespace (top 5) and one against the
    column namespace (top 10). Hits whose ID is already recorded in the
    matching ``already_seen_chunk_*`` set are skipped; the rest are recorded
    and become ``state.relevant_tables`` / ``state.relevant_columns``.
    """
    question_text = state.rewritten_question

    table_search = index.search(
        namespace="table-details-university",
        query={"inputs": {"text":  question_text}, "top_k": 5},
    )
    column_search = index.search(
        namespace="column-details-university",
        query={"inputs": {"text":  question_text}, "top_k": 10},
    )

    # Tables keep the whole `fields` object of each unseen hit.
    unseen_tables = []
    for hit in table_search.result.hits:
        if hit._id in state.already_seen_chunk_table:
            continue
        unseen_tables.append(hit.fields)
        state.already_seen_chunk_table.add(hit._id)

    # Columns keep only the chunk text of each unseen hit.
    unseen_columns = []
    for hit in column_search.result.hits:
        if hit._id in state.already_seen_chunk_column:
            continue
        unseen_columns.append(hit.fields.get("chunk_text", ""))
        state.already_seen_chunk_column.add(hit._id)

    state.relevant_tables = unseen_tables
    state.relevant_columns = unseen_columns
    return state


In [10]:
from langchain.prompts import ChatPromptTemplate

def extract_sql_block(text: str) -> str:
    """Pull the SQL statement out of an LLM reply.

    Handled shapes, in priority order:
      1. a fenced block tagged ``sql`` (tag matched case-insensitively);
      2. a fenced block without a language tag;
      3. text wrapped in single backticks, or plain unfenced text.

    Args:
        text: Raw ``sql_query`` string produced by the model (may be empty or None).

    Returns:
        The stripped SQL text, or "" when the input is empty or is the
        "N/A" placeholder the model is told to use for insufficient data.
    """
    if not text:
        return ""

    # Prefer an explicitly tagged block, then fall back to any fenced block.
    for pattern in (r"```sql\s+(.*?)```", r"```\s*(.*?)```"):
        fenced = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if fenced:
            return fenced.group(1).strip()

    # The generation prompt also allows single backticks or bare SQL.
    bare = text.strip().strip("`").strip()
    if bare.upper() in {"", "N/A"}:
        return ""
    return bare


def generate_sql_query(state : AgentState)->AgentState:
    """Ask the LLM for a SQL query answering ``state.updated_question``.

    The prompt receives the retrieved tables, columns and few-shot examples.
    The structured reply (``SQLQueryResponse``) is written back to the state:
    the SQL text goes through ``extract_sql_block``, while the explanation and
    the sufficiency flag are copied as returned.
    """
    system_message = ("system", """You are a helpful AI that generates SQL queries and explains them.""")
    human_message = ("human",
    """
    You are a helpful AI assistant designed to generate SQL queries and explain them clearly.

    Your task is to generate a **valid SQL query** and a **step-by-step explanation** based on:
    1. A natural language question from the user.
    2. A list of relevant tables.
    3. A list of relevant columns with metadata (including descriptions, types, and foreign key relationships).
    4. A few-shot list of example SQL queries with their corresponding questions (for context; they may or may not be directly related).

    ---


    **Guidelines:**
    - DON'T ASSUME ANYTHING BY YOURSELF!!!! 
    - Use **only** the tables and columns provided.
    - When filtering by a string (e.g., using `WHERE`), **always wrap the string in single quotes** `'like this'` — never use double quotes `"like this"`.
    - Example: `WHERE department = 'Physics'` ✅  
               `WHERE department = "Physics"` ❌

    - Use appropriate **JOINs** by identifying how foreign keys connect one table to another. Multiple joins across related tables are allowed and encouraged if needed to fulfill the query.
    - Use **aggregation** (such as `COUNT`, `AVG`, `SUM`, etc.) when it is required to summarize data as part of answering the question.
    - Always analyze if multiple tables need to be joined to access necessary information (e.g., join User → Certificate → Institute to find users with certificates from a specific institute).
    - Columns come with semantic descriptions; **infer intent** from those descriptions even if exact wording doesn't match the question.
    

    ### Natural Language Question:
    {user_question}

    ### Relevant Tables:
    {relevant_tables}

    ### Relevant Columns:
    {relevant_columns}

    ### Few-shot Examples:
    {few_shot_examples}

    ---

    Return your answer in this structured format:
    - `sql_query`: the final SQL query string (wrapped in backticks, like `SELECT * ...`) (or `"N/A"` if insufficient data)
    - `explanation`: detailed explanation of how the query was constructed or what information is missing
    - `is_sufficient_data`: Boolean indicating whether the provided data is sufficient to construct a valid SQL query
    """)
    sql_prompt = ChatPromptTemplate.from_messages([system_message, human_message])

    # Constrain the reply to the SQLQueryResponse schema.
    structured_llm = llm.with_structured_output(SQLQueryResponse)

    prompt_inputs = {
        "user_question": state.updated_question,
        "relevant_tables": state.relevant_tables,
        "relevant_columns": state.relevant_columns,
        "few_shot_examples": state.relevant_queries,
    }
    llm_reply = (sql_prompt | structured_llm).invoke(prompt_inputs)

    # Keep only the SQL statement; fences are removed by extract_sql_block.
    state.sql_query = extract_sql_block(llm_reply.sql_query)
    state.explanation = llm_reply.explanation
    state.is_sufficient_data = llm_reply.is_sufficient_data

    return state

In [11]:
def is_query_generated(state : AgentState)->Literal["get_more_table_column", "execute_query", END]:
    """Route after SQL generation.

    Sufficient data goes to execution. Insufficient data goes to another
    metadata fetch while the fetch budget lasts; otherwise the graph ends.
    """
    sufficient = state.is_sufficient_data

    if sufficient == True:
        return "execute_query"
    if sufficient == False and state.remaining_datafetch > 0:
        return "get_more_table_column"
    return END

In [12]:
def get_more_table_column(state : AgentState)->AgentState:
    """Fetch additional table/column metadata and add it to the existing context.

    The search text is ``state.explanation`` (the model's description of what
    is missing). When that is empty or None, the rewritten question is used
    instead so the search never runs with an empty query.

    Up to 3 table hits and 5 column hits are requested. Hits whose chunk ID was
    already seen are skipped. New hits are APPENDED to ``state.relevant_tables``
    and ``state.relevant_columns``: already-seen chunks are filtered out, so
    replacing the lists would drop the context retrieved earlier and leave the
    next generation attempt with only the new chunks.

    Consumes one unit of ``state.remaining_datafetch``.
    """
    search_text = state.explanation or state.rewritten_question

    table_fetch = index.search(
        namespace="table-details-university",
        query={
            "inputs": {"text":  search_text},
            "top_k": 3
        },
    )

    column_fetch = index.search(
        namespace="column-details-university",
        query={
            "inputs": {"text":  search_text},
            "top_k": 5
        },
    )

    new_tables = []
    new_columns = []

    for docs in table_fetch.result.hits:
        chunkid = docs._id

        if chunkid not in state.already_seen_chunk_table:
            new_tables.append(docs.fields)
            state.already_seen_chunk_table.add(chunkid)

    for docs in column_fetch.result.hits:
        chunkid = docs._id

        if chunkid not in state.already_seen_chunk_column:
            new_columns.append(docs.fields.get("chunk_text", ""))
            state.already_seen_chunk_column.add(chunkid)

    state.remaining_datafetch -= 1
    # Build new lists instead of mutating the existing ones in place.
    state.relevant_tables = list(state.relevant_tables) + new_tables
    state.relevant_columns = list(state.relevant_columns) + new_columns

    return state

In [13]:
from dataclasses import replace
import psycopg2
def execute_query(state : AgentState)->AgentState:
    """Run ``state.sql_query`` against the ``university_db`` PostgreSQL database.

    On success ``state.result`` holds one dict per row (column name -> value),
    ``state.query_executed_successfully`` is True and ``state.error_message``
    is empty. On any failure the result is an empty list, the flag is False and
    the error message holds the exception text, which the regeneration step
    feeds back to the LLM.

    The connection is always closed, including when connecting or executing
    fails. The session is opened read-only because the SQL is LLM-generated
    and must not modify data.
    """
    database = "university_db"
    rows_as_dicts = []
    error_str = ""
    succeeded = False
    conn = None

    try:
        conn = psycopg2.connect(
            host=host,
            port=port,
            database=database,
            user=user,
            password=password
        )
        # Defense in depth: reject writes issued by generated SQL.
        conn.set_session(readonly=True)

        # The cursor context manager closes the cursor even if execute() raises.
        with conn.cursor() as cur:
            cur.execute(state.sql_query)

            # description is None for statements that return no result set.
            if cur.description is None:
                raise ValueError(
                    "The statement did not return a result set; only queries that return rows are supported."
                )

            column_names = [desc[0] for desc in cur.description]
            rows_as_dicts = [dict(zip(column_names, row)) for row in cur.fetchall()]

        succeeded = True

    except Exception as e:
        # Deliberately broad: any failure is reported back through the state
        # so the graph can try to regenerate the query.
        error_str = str(e)
        rows_as_dicts = []

    finally:
        if conn is not None:
            conn.close()

    state.query_executed_successfully = succeeded
    state.result = rows_as_dicts
    state.error_message = error_str

    return state

 

In [14]:
def is_sqlquery_right(state : AgentState)->Literal["regenerate_query", "generate_chart_insight", END]:
    """Route after query execution.

    A successful execution continues to chart insight generation. A failed one
    is retried through query regeneration while the regeneration budget lasts;
    otherwise the graph ends.
    """
    if state.query_executed_successfully == True:
        return "generate_chart_insight"

    can_retry = state.remaining_querygen > 0
    return "regenerate_query" if can_retry else END

In [15]:

class ChartSuggestion(BaseModel):
    """Chart recommendation for a query result."""
    chart_type: Optional[str] = Field(
        None,
        description="One of: bar, line, pie, histogram, scatter, grouped_bar, multi_line. Leave empty if no chart is suitable."
    )
    x_axis: Optional[str] = Field(
        None,
        description="Column name for the X-axis, if applicable."
    )
    y_axis: Optional[str] = Field(
        None,
        description="Column name for the Y-axis, if applicable."
    )
    reason: str = Field(
        ...,
        description="Why this chart is appropriate, or why no chart is possible."
    )


def _describe_result_columns(rows):
    """Return one {"column_name", "dtype"} entry per column of the first row.

    The dtype is the Python type name of the first non-None value found in
    that column, or "unknown" when every value is None.
    """
    schema = []
    for column in rows[0].keys():
        sample_value = next(
            (row[column] for row in rows if row.get(column) is not None),
            None,
        )
        dtype = type(sample_value).__name__ if sample_value is not None else "unknown"
        schema.append({"column_name": column, "dtype": dtype})
    return schema


def generate_chart_insight(state: AgentState) -> AgentState:
    """Ask the LLM whether, and how, the query result should be charted.

    Reads the rows from ``state.result`` (a list of dicts built by
    execute_query) and the original question, then sets ``chart_type``,
    ``x_axis``, ``y_axis`` and ``insight`` on the state and returns it. With no
    rows, only ``insight`` is set and the LLM is not called.

    The previous version read a nonexistent ``state.query_result``, used an
    unimported ``pd`` and an undefined ``ChartSuggestion``, and passed
    undeclared fields to ``dataclasses.replace``; it could not run.

    NOTE: the chart attributes are assigned directly on the state object. They
    only persist as graph state if ``AgentState`` declares fields with the
    same names.
    """
    print("Entering generate_chart_insight.....")

    rows = state.result
    if not rows:
        state.insight = "No data to visualize."
        return state

    question = state.question if hasattr(state, "question") else "Unknown"


    prompt = ChatPromptTemplate.from_messages([
    ("system", """You are a helpful data visualization assistant."""),
    ("human",
    """
        You are a helpful data visualization assistant.
        Your job is to decide whether a chart should be generated based on the user's question, column schema, and sample data.

        ## Your Responsibilities:
        1. Analyze the user's intent from the natural language question.
        2. Determine if a chart will be helpful to visualize the answer.
        3. If yes, choose a suitable chart type from this list:
        - bar, line, pie, histogram, scatter, grouped_bar, multi_line
        4. Select one column for the X-axis and one for the Y-axis (if applicable).
        - Not all columns must be used.
        5. If no chart is suitable, clearly explain why (e.g., text-based result, too few rows, not numeric data, etc.)

        ### Input:
        Question: {question}

        Schema:
        {schema}

        Sample Rows:
        {sample}

        ### Output format:
        - chart_type: one of the allowed types, or leave empty if no chart is suitable
        - x_axis: column name for X-axis (if applicable)
        - y_axis: column name for Y-axis (if applicable)
        - reason: why this chart is appropriate OR why it is not possible
    """)
    ])

    # Column names with inferred types, plus the first three rows as a sample.
    schema = _describe_result_columns(rows)
    sample = rows[:3]

    print(schema)
    print(sample)

    llm_chart_chain = prompt | llm.with_structured_output(ChartSuggestion)

    suggestion = llm_chart_chain.invoke({
        "question": question,
        "schema": schema,
        "sample": sample
    })

    state.chart_type = suggestion.chart_type
    state.x_axis = suggestion.x_axis
    state.y_axis = suggestion.y_axis
    state.insight = suggestion.reason

    return state


In [16]:
from langchain.prompts import ChatPromptTemplate
from typing import Optional
from pydantic import BaseModel, Field

# Structured output schema for the SQL-repair step: regenerate_query passes
# this model to `llm.with_structured_output`.
# The `description` strings are runtime text belonging to the schema.
class CorrectedQuery(BaseModel):
    # Repaired SQL, or the literal 'N/A' when the failure is caused by
    # unknown table/column names (checked by regenerate_query / can_fix_sqlerror).
    corrected_sql: str = Field(
        ...,
        description=(
            "The revised SQL query that resolves the error and correctly answers the original question. "
            "If the error is due to incorrect or unknown table/column names, set this to 'N/A'."
        )
    )
    # Optional (defaults to None); regenerate_query copies it into
    # state.explanation when corrected_sql is 'N/A'.
    reason_for_fix: Optional[str] = Field(
        None,
        description=(
            "A clear explanation of what was wrong with the original query. "
            "If the fix was not possible due to invalid names (e.g., missing table or column), "
            "list which table/column names were invalid and describe what kind of table/column is needed. "
            "This will be used to re-fetch relevant metadata."
        )
    )




def regenerate_query(state : AgentState)->AgentState:
    """Ask the LLM to repair the SQL query that just failed to execute.

    The prompt receives the question, the failing SQL, its explanation and the
    database error message. One unit of ``state.remaining_querygen`` is consumed.

    The returned SQL is normalized before it is stored: surrounding whitespace,
    code fences and single backticks are removed, so a fenced reply is not
    executed verbatim. If the model answers "N/A" (any case) or returns no SQL
    at all, ``state.sql_query`` is set to exactly "N/A" so ``can_fix_sqlerror``
    routes to another metadata fetch. ``state.explanation`` is replaced by the
    model's reason only when a reason was actually given, so it is never
    overwritten with None.
    """
    prompt = ChatPromptTemplate.from_messages([
    ("system", """You are a helpful AI that generates SQL queries and explains them."""),
    ("human",
    """
        You are a highly skilled SQL assistant helping to correct SQL queries.

        You are given:
        1. An **original user question**
        2. A previously generated SQL query that caused an execution error.
        3. The **error message** returned by the database.
        4. A brief **explanation of the query logic** that was originally followed.

        Your task is to:
        - Analyze the error message and understand what went wrong.
        - Use the explanation of the original query to understand the user's intent.
        - Fix the SQL query accordingly.
        - Ensure that the new query uses correct table and column names, and valid SQL syntax.
        - Do **not** repeat the same structural mistake.
        - If the query failed due to incorrect table or column names (that don't exist), DO NOT guess. Instead:
            - Set `corrected_sql` to "N/A"
            - In `reason_for_fix`, explain which table or column was invalid and what kind of table/column is needed.
                (e.g., "The column `student_score` does not exist. A numeric column representing student performance is needed.")


        ### Original User Question:
        {question}

        ### Original SQL Query:
        {original_sql}

        ### Query Explanation:
        {explanation}

        ### Error Message:
        {error}

    """)
    ])

    query_regenerate_chain = prompt | llm.with_structured_output(CorrectedQuery)

    response = query_regenerate_chain.invoke({
        "question": state.updated_question,
        "original_sql": state.sql_query,
        "explanation": state.explanation,
        "error": state.error_message,
    })

    state.remaining_querygen -= 1

    # Normalize the reply: unwrap a fenced block if present, else drop stray backticks.
    raw_sql = (response.corrected_sql or "").strip()
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", raw_sql, re.DOTALL | re.IGNORECASE)
    if fenced:
        cleaned_sql = fenced.group(1).strip()
    else:
        cleaned_sql = raw_sql.strip("`").strip()

    if not cleaned_sql or cleaned_sql.upper() == "N/A":
        # Not fixable with the known metadata: signal a re-fetch.
        state.sql_query = "N/A"
        if response.reason_for_fix:
            # The reason becomes the search text for the next metadata fetch.
            state.explanation = response.reason_for_fix
    else:
        state.sql_query = cleaned_sql

    return state


In [17]:
def can_fix_sqlerror(state : AgentState)->Literal["get_more_table_column", "execute_query"]:
    """Route after query regeneration.

    The literal "N/A" means the query could not be fixed, so more metadata is
    fetched; any other SQL text is sent to execution.
    """
    needs_more_metadata = state.sql_query == "N/A"
    return "get_more_table_column" if needs_more_metadata else "execute_query"
    

In [18]:
from IPython.display import Image, display


def _build_workflow():
    """Assemble the text-to-SQL state graph: register the nodes, then wire the edges."""
    builder = StateGraph(AgentState, input=InputState)

    # (node name, handler) pairs, registered in this order.
    node_handlers = (
        ("rewrite_question", rewrite_question),
        ("get_relevant_queries", get_relevant_queries),
        ("get_relevant_table_column", get_table_and_columns),
        ("generate_sql_query", generate_sql_query),
        ("get_more_table_column", get_more_table_column),
        ("execute_query", execute_query),
        ("regenerate_query", regenerate_query),
        ("generate_chart_insight", generate_chart_insight),
    )
    for node_name, handler in node_handlers:
        builder.add_node(node_name, handler)

    # Linear part: rewrite -> example queries -> table/column metadata -> SQL generation.
    builder.add_edge(START, "rewrite_question")
    builder.add_edge("rewrite_question", "get_relevant_queries")
    builder.add_edge("get_relevant_queries", "get_relevant_table_column")
    builder.add_edge("get_relevant_table_column", "generate_sql_query")

    # After generation: execute, fetch more metadata, or end.
    builder.add_conditional_edges("generate_sql_query", is_query_generated)
    builder.add_edge("get_more_table_column", "generate_sql_query")

    # After execution: chart insight, regenerate, or end.
    builder.add_conditional_edges("execute_query", is_sqlquery_right)
    # After regeneration: re-execute or fetch more metadata.
    builder.add_conditional_edges("regenerate_query", can_fix_sqlerror)
    builder.add_edge("generate_chart_insight", END)

    return builder


workflow = _build_workflow()
graph = workflow.compile()



In [None]:
# Render the compiled graph as a Mermaid PNG and show it inline.
# Text alternative (Mermaid source instead of an image): graph.get_graph().draw_mermaid()
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Table names offered to the question-rewriting prompt.
tables = [
    "students",
    "courses",
    "professors",
    "departments",
    "enrollments",
    "classrooms",
    "schedules",
    "grades",
    "assignments",
    "submissions",
    "clubs",
    "club_members",
]

# Single-turn conversation containing the question to answer.
message = [
    HumanMessage(content="What is the number of students in each batch year wise?"),
]

input_state = InputState(messages=message, tables=tables)


In [None]:
# Run the compiled graph on the prepared input; `ans` holds the value returned by graph.invoke.
ans = graph.invoke(input=input_state)

In [81]:
# Show the value returned by the graph run (`ans` is assigned in the previous cell).
print(ans)

{'remaining_datafetch': 1, 'remaining_querygen': 3, 'question': 'What is the number of students in each batch year wise?', 'rewritten_question': 'What is the count of students for each batch year?', 'updated_question': 'I need the count of students grouped by batch year.  The result should show the batch year and the number of students in each batch.', 'relevant_queries': ['Question: Count total number of students.\nSQL query: SELECT COUNT(*) FROM students;', 'Question: Get all students who enrolled after 2020.\nSQL query: SELECT first_name, last_name FROM students WHERE enrollment_year > 2020;', 'Question: Get email addresses of students who joined in 2023.\nSQL query: SELECT email FROM students WHERE enrollment_year = 2023;', 'Question: List departments and number of courses, professors, and students in each.\nSQL query: SELECT d.department_name, (SELECT COUNT() FROM courses WHERE department_id = d.department_id) AS total_courses, (SELECT COUNT() FROM professors WHERE department_id =