## CA 4 - Part 2, LLMs Spring 2025

- **Name:**
- **Student ID:**

---
#### Your submission should be named using the following format: `CA4_LASTNAME_STUDENTID.ipynb`.

---

TA Email: miladmohammadi@ut.ac.ir

##### *How to do this problem set:*

- Some questions require writing Python code and computing results, and the rest of them have written answers. For coding problems, you will have to fill out all code blocks that say `YOUR CODE HERE`.

- For text-based answers, you should replace the text that says ```Your Answer Here``` with your actual answer.

- There is no penalty for using AI assistance on this homework as long as you fully disclose it in the final cell of this notebook (this includes storing any prompts that you feed to large language models). That said, anyone caught using AI assistance without proper disclosure will receive a zero on the assignment (we have several automatic tools to detect such cases). We're literally allowing you to use it with no limitations, so there is no reason to lie!

---

##### *Academic honesty*

- We will audit the Colab notebooks from a set number of students, chosen at random. The audits will check that the code you wrote actually generates the answers in your notebook. If you turn in correct answers on your notebook without code that actually generates those answers, we will consider this a serious case of cheating.

- We will also run automatic checks of Colab notebooks for plagiarism. Copying code from others is also considered a serious case of cheating.

---

## Text2SQL

In this section, you will progressively build and evaluate multiple Text-to-SQL pipelines. You’ll start with a simple prompting-based baseline, then design a graph-based routing system using chain-of-thought and schema reasoning, and finally construct a ReAct agent that interacts with the schema via tools. Each stage demonstrates a different strategy for generating SQL from natural language using LLMs.

### Initializations

This section prepares the environment and initializes the LLM model (Gemini) to be used in later parts of the notebook.

In [2]:
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


#### Load API Key (2 Points)

**Task:** Load the Gemini API key stored in the `.env` file and set it as an environment variable so it can be used to authenticate API requests later.

* Use `dotenv` to load the file.
* Extract the API key with `os.getenv`.

In [None]:
# import os
# from dotenv import load_dotenv

# # Load environment variables from .env file
# load_dotenv()

# # Extract the API key
# GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# if not GOOGLE_API_KEY:
#     raise ValueError("GOOGLE_API_KEY not found in environment variables")

# # Set the API key as an environment variable for use by langchain-google-genai
# os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

#### Create ChatModel (3 Points)

**Task:** Create an instance of the Gemini LLM using LangChain. You should configure the model with proper parameters for our task.

Note: You may use any model that supports Structured Output and Tool Use. We recommend using gemini-2.5-flash-preview-05-20 from Google AI Studio, as it offers a generous free tier.

In [None]:
# from langchain_google_genai import ChatGoogleGenerativeAI

# # Create an instance of the Gemini LLM
# llm = ChatGoogleGenerativeAI(
#     model="gemini-2.0-flash-exp",  # Using the recommended model
#     temperature=0,  # Set to 0 for more deterministic SQL generation
#     max_tokens=2048,  # Sufficient for SQL queries
#     top_p=1.0,
#     top_k=40
# )

In [4]:
from dotenv import load_dotenv
import os
from langchain.chat_models import init_chat_model

load_dotenv()
llm = init_chat_model("anthropic:claude-3-7-sonnet-latest")

### Baseline

In this section, you'll build a simple baseline pipeline that directly converts a question and schema into a SQL query using a single prompt.

#### Baseline Function (5 Points)

**Task:** Implement a function that sends a system message defining the task, and a user message containing the input question and schema. The LLM should return the SQL query formatted as: "```sql\n[query]```"

In [5]:
def run_baseline(question: str, schema: str):
    # Create system message defining the task
    system_message = """You are an expert SQL query generator. Given a natural language question and database schema, generate a syntactically correct SQL query that answers the question.

Instructions:
- Analyze the question and schema carefully
- Generate only the SQL query without any explanation
- Format your response as: ```sql\n[query]```
- Ensure the query is syntactically correct and uses only tables/columns from the provided schema
- Use proper SQL syntax and conventions"""

    # Create user message with question and schema
    user_message = f"""Question: {question}

Database Schema:
{schema}

Generate the SQL query to answer this question."""

    # Get response from LLM
    response = llm.invoke([
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message}
    ])
    
    sql_query = response.content
    return sql_query

#### Run and Evaluate (Estimated Run Time 5-10min)

Run your baseline function over the dataset provided.

In [6]:
from method_run import run_method
import re

def function_template(item):
    result = run_baseline(item['question'], item['schema'])
    # First try to extract query from markdown SQL block
    match = re.search(r'```sql\n(.*?)```', result, re.DOTALL)
    if match:
        query = match.group(1).strip()
    else:
        # If no markdown block found, try to extract just SQL query
        query = result.strip()
        # Remove any ```sql or ``` if present without proper formatting
        query = re.sub(r'```sql|```', '', query).strip()
    
    print(f"Question: {item['question']}")
    print(f"Schema: {item['schema']}")
    print(f"Generated SQL: {query}\n")
    
    return {**item, 'sql': query}

run_method(function_template, SLEEP_TIME=10)

#Run on mode=nano if you want to test it on a smaller dataset
#run_method(function_template, SLEEP_TIME=10, mode="nano")

File not found: data/dev_databases/student_club/database_description/event.csv
File not found: data/dev_databases/student_club/database_description/major.csv
File not found: data/dev_databases/student_club/database_description/zip_code.csv
File not found: data/dev_databases/student_club/database_description/attendance.csv
File not found: data/dev_databases/student_club/database_description/budget.csv
File not found: data/dev_databases/student_club/database_description/expense.csv
File not found: data/dev_databases/student_club/database_description/income.csv
File not found: data/dev_databases/student_club/database_description/member.csv


  0%|          | 0/18 [00:00<?, ?it/s]

Question: Find the percentage of atoms with single bond. (Evidence: single bond refers to bond_type = '-'; percentage = DIVIDE(SUM(bond_type = '-'), COUNT(bond_id)) as percentage)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Generated SQL: SELECT (SUM(CASE WHEN bond_type = '-' THEN 1 ELSE 0 END) * 100.0 / COUNT(bond_id)) AS percentage
FROM bond



  6%|▌         | 1/18 [00:12<03:25, 12.08s/it]

Question: Indicate which atoms are connected in non-carcinogenic type molecules. (Evidence: label = '-' means molecules are non-carcinogenic)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Generated SQL: SELECT DISTINCT a1.atom_id, a2.atom_id2
FROM atom a1
JOIN connected c ON a1.atom_id = c.atom_id
JOIN atom a2 ON c.atom_id2 = a2.atom_id
JOIN molecule m ON a1.molecule_id = m.molecule_id
WHERE m.label = '-'



 11%|█         | 2/18 [00:25<03:24, 12.79s/it]

Question: What is the average number of bonds the atoms with the element iodine have? (Evidence: atoms with the element iodine refers to element = 'i'; average = DIVIDE(COUND(bond_id), COUNT(atom_id)) where element = 'i')
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Generated SQL: SELECT AVG(atom_bonds.bond_count) AS avg_bonds_for_iodine
FROM (
    SELECT a.atom_id, COUNT(c.bond_id) AS bond_count
    FROM atom a
    LEFT JOIN connected c ON a.atom_id = c.atom_id
    WHERE a.element = 'i'
    GROUP BY a.atom_id
) AS atom_bonds



 17%|█▋        | 3/18 [00:38<03:12, 12.86s/it]

Question: List down two molecule id of triple bond non carcinogenic molecules with element carbon. (Evidence: carbon refers to element = 'c'; triple bond refers to bond_type = '#'; label = '-' means molecules are non-carcinogenic)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Generated SQL: SELECT DISTINCT a.molecule_id
FROM atom a
JOIN bond b ON a.molecule_id = b.molecule_id
JOIN molecule m ON a.molecule_id = m.molecule_id
WHERE a.element = 'c'
AND b.bond_type = '#'
AND m.label = '-'
LIMIT 2



 22%|██▏       | 4/18 [00:51<02:59, 12.85s/it]

Question: What are the elements of the toxicology and label of molecule TR060? (Evidence: TR060 is the molecule id; label = '+' mean molecules are carcinogenic; label = '-' means molecules are non-carcinogenic; element = 'cl' means Chlorine; element = 'c' means Carbon; element = 'h' means Hydrogen; element = 'o' means Oxygen, element = 's' means Sulfur; element = 'n' means Nitrogen, element = 'p' means Phosphorus, element = 'na' means Sodium, element = 'br' means Bromine, element = 'f' means Fluorine; element = 'i' means Iodine; element = 'sn' means Tin; element = 'pb' means Lead; element = 'te' means Tellurium; element = 'ca' means Calcium)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Generated SQL: SELECT a.element, m.label
FROM atom a
JOIN molecule m ON a.molecule_id = m.molecule_id
WHERE a.molecule_id = 'TR060'



 28%|██▊       | 5/18 [01:03<02:45, 12.74s/it]

Question: What are the elements for bond id TR001_10_11? (Evidence: element = 'cl' means Chlorine; element = 'c' means Carbon; element = 'h' means Hydrogen; element = 'o' means Oxygen, element = 's' means Sulfur; element = 'n' means Nitrogen, element = 'p' means Phosphorus, element = 'na' means Sodium, element = 'br' means Bromine, element = 'f' means Fluorine; element = 'i' means Iodine; element = 'sn' means Tin; element = 'pb' means Lead; element = 'te' means Tellurium; element = 'ca' means Calcium)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Generated SQL: SELECT DISTINCT a.element
FROM atom a
JOIN connected c ON a.atom_id = c.atom_id OR a.atom_id = c.atom_id2
WHERE c.bond_id = 'TR001_10_11'



 33%|███▎      | 6/18 [01:15<02:30, 12.53s/it]

Question: How many superheroes were published by Dark Horse Comics? (Evidence: published by Dark Horse Comics refers to publisher_name = 'Dark Horse Comics';)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Generated SQL: SELECT COUNT(*) 
FROM superhero s
JOIN publisher p ON s.publisher_id = p.id
WHERE p.publisher_name = 'Dark Horse Comics'



 39%|███▉      | 7/18 [01:27<02:16, 12.39s/it]

Question: What are the race and alignment of Cameron Hicks? (Evidence: Cameron Hicks refers to superhero_name = 'Cameron Hicks';)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Generated SQL: SELECT r.race, a.alignment
FROM superhero s
JOIN race r ON s.race_id = r.id
JOIN alignment a ON s.alignment_id = a.id
WHERE s.superhero_name = 'Cameron Hicks'



 44%|████▍     | 8/18 [01:39<02:01, 12.13s/it]

Question: Among the superheroes with height from 170 to 190, list the names of the superheroes with no eye color. (Evidence: height from 170 to 190 refers to height_cm BETWEEN 170 AND 190; no eye color refers to eye_colour_id = 1)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Generated SQL: SELECT s.superhero_name
FROM superhero s
WHERE s.height_cm BETWEEN 170 AND 190
AND s.eye_colour_id = 1



 50%|█████     | 9/18 [01:51<01:48, 12.06s/it]

Question: List down at least five superpowers of male superheroes. (Evidence: male refers to gender = 'Male'; superpowers refers to power_name;)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Generated SQL: SELECT DISTINCT sp.power_name
FROM superpower sp
JOIN hero_power hp ON sp.id = hp.power_id
JOIN superhero sh ON hp.hero_id = sh.id
JOIN gender g ON sh.gender_id = g.id
WHERE g.gender = 'Male'
LIMIT 5



 56%|█████▌    | 10/18 [02:03<01:36, 12.05s/it]

Question: What is the percentage of superheroes who act in their own self-interest or make decisions based on their own moral code? Indicate how many of the said superheroes were published by Marvel Comics. (Evidence: published by Marvel Comics refers to publisher_name = 'Marvel Comics'; superheroes who act in their own self-interest or make decisions based on their own moral code refers to alignment = 'Bad'; calculation = MULTIPLY(DIVIDE(SUM(alignment = 'Bad); count(id)), 100))
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Generated SQL: WITH BadHeroes AS (
    SELECT s.id, s.publisher_id
    FROM superhero s
 

 61%|██████    | 11/18 [02:16<01:27, 12.47s/it]

Question: Which publisher created more superheroes: DC or Marvel Comics? Find the difference in the number of superheroes. (Evidence: DC refers to publisher_name = 'DC Comics'; Marvel Comics refers to publisher_name = 'Marvel Comics'; if SUM(publisher_name = 'DC Comics') > SUM(publisher_name = 'Marvel Comics'), it means DC Comics published more superheroes than Marvel Comics; if SUM(publisher_name = 'Marvel Comics') > SUM(publisher_name = 'Marvel Comics'), it means Marvel Comics published more heroes than DC Comics; difference = SUBTRACT(SUM(publisher_name = 'DC Comics'), SUM(publisher_name = 'Marvel Comics'));)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (i

 67%|██████▋   | 12/18 [02:30<01:16, 12.79s/it]

Question: Who was the first one paid his/her dues? Tell the full name. (Evidence: full name refers to first_name, last_name; first paid dues refers to MIN(received_date) where source = 'Dues')
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Generated SQL: SELECT m.first_name, m.last_name
FROM member m
JOIN income i ON m.member_id = i.link_to_member
WHERE i.source = 'Dues'
ORDER BY i.date_received ASC
LIMIT 1



 72%|███████▏  | 13/18 [02:42<01:02, 12.45s/it]

Question: How many income are received with an amount of 50? (Evidence: amount of 50 refers to amount = 50)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Generated SQL: SELECT COUNT(*) FROM income WHERE amount = 50



 78%|███████▊  | 14/18 [02:53<00:48, 12.24s/it]

Question: Name the event with the highest amount spent on advertisement. (Evidence: event refers to event_name; highest amount spent on advertisement refers to MAX(spent) where category = 'Advertisement')
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Generated SQL: SELECT e.event_name
FROM event e
JOIN budget b ON e.event_id = b.link_to_event
WHERE b.category = 'Advertisement'
ORDER BY b.spent DESC
LIMIT 1;



 83%|████████▎ | 15/18 [03:06<00:36, 12.24s/it]

Question: Based on the total cost for all event, what is the percentage of cost for Yearly Kickoff event? (Evidence: DIVIDE(SUM(cost where event_name = 'Yearly Kickoff'), SUM(cost)) * 100)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Generated SQL: SELECT 
  (SUM(CASE WHEN e.event_name = 'Yearly Kickoff' THEN ex.cost ELSE 0 END) / SUM(ex.cost)) * 100 AS yearly_kickoff_percentage
FROM 
  expense ex
JOIN 
  budget b ON ex.link_to_budget = b.budget_

 89%|████████▉ | 16/18 [03:18<00:24, 12.35s/it]

Question: Calculate the total average cost that Elijah Allen spent in the events on September and October. (Evidence: events in September and October refers to month(expense_date) = 9 AND MONTH(expense_date) = 10)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Generated SQL: SELECT AVG(e.cost) as average_cost
FROM expense e
JOIN member m ON e.link_to_member = m.member_id
WHERE m.first_name = 'Elijah' 
AND m.last_name = 'Allen'
AND (MONTH(e.expense_

 94%|█████████▍| 17/18 [03:32<00:12, 12.91s/it]

Question: Find the name and date of events with expenses for pizza that were more than fifty dollars but less than a hundred dollars. (Evidence: name of event refers to event_name; date of event refers to event_date; expenses for pizza refers to expense_description = 'Pizza' where cost > 50 and cost < 100)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Generated SQL: SELECT e.event_name, e.event_date
FROM event e
JOIN budget b ON e.event_id = b.lin

100%|██████████| 18/18 [03:46<00:00, 12.58s/it]

Starting to compare without knowledge for ex
Process finished successfully
start calculate
                     simple               moderate             challenging          total               
count                6                    6                    6                    18                  
accuracy             66.67                100.00               83.33                83.33               
Finished evaluation






### Chain/Router

Here, you will build a more advanced system that routes the query through different paths based on question difficulty. Easier questions go straight to query generation; harder ones go through schema path extraction first.

#### Define State (5 Points)

**Task:** Define a `RouterGraphState` using `MessagesState` and `pydantic` that contains:
* The input question and schema
* The predicted difficulty level
* The extracted schema path
* The final query

In [7]:
from langgraph.graph import MessagesState
from typing import Literal, Optional

class RouterGraphState(MessagesState):
    # Input question and schema
    question: str
    schema: str
    
    # Predicted difficulty level
    question_difficulty: Optional[str] = None
    
    # Extracted schema path for complex queries
    schema_path: Optional[str] = None
    
    # Final generated query
    query: Optional[str] = None

#### Node: Analyser (5 Points)

**Task:** Build a node that:
* Accepts a question and schema
* Analyzes the difficulty (simple/moderate/challanging)
* Uses the LLM’s structured output feature to return the difficulty

**Steps**:

1. Define a Pydantic class to hold the expected structured output.
2. Use structure output mode of LLM to bind it to the model.

In [8]:
from pydantic import BaseModel

class QuestionDifficultyAnalysis(BaseModel):
    difficulty: Literal["simple", "moderate", "challenging"]
    reasoning: str

def analyser_node(state: RouterGraphState):
    # Create structured output LLM for difficulty analysis
    structured_llm = llm.with_structured_output(QuestionDifficultyAnalysis)
    
    prompt = f"""Analyze the difficulty of this Text-to-SQL question based on the following criteria:

SIMPLE: 
- Single table queries
- Basic SELECT, WHERE, ORDER BY operations
- No joins or complex aggregations

MODERATE:
- 2-3 table joins
- Basic aggregations (COUNT, SUM, AVG)
- Simple subqueries

CHALLENGING:
- Complex multi-table joins (4+ tables)
- Complex aggregations or window functions
- Nested subqueries or CTEs
- Advanced SQL features

Question: {state["question"]}
Schema: {state["schema"]}

Provide your difficulty assessment and reasoning."""

    response = structured_llm.invoke(prompt)
    
    # Update state with difficulty analysis
    state["question_difficulty"] = response.difficulty
    
    return state

#### Conditional Edge (2 Points)

**Task:** Implement a branching function that decides whether to proceed to direct query generation or schema path extraction based on the difficulty label returned by the analyser.

* If the difficulty is “easy”, go directly to query generation.
* Otherwise, extract the schema path first.

In [9]:
def is_schema_extraction_needed(state: RouterGraphState) -> Literal["schema_path_extractor", "query_generator"]:
    # If the difficulty is simple, go directly to query generation
    if state["question_difficulty"] == "simple":
        return "query_generator"
    else:
        # For moderate and challenging questions, extract schema path first
        return "schema_path_extractor"

#### Node: Schema Extractor (3 Points)

**Task:** Implement a node that takes the question and schema and extracts a join path or sequence of relevant tables from the schema based on the question.

* Use a simple prompt for this.
* Store the result in the `schema_path` field of the state.

In [10]:
def schema_path_extractor_node(state: RouterGraphState):
    prompt = f"""Given a natural language question and database schema, identify the relevant tables and their relationships needed to answer the question. Extract a schema path that shows the join sequence between tables.

Question: {state["question"]}

Database Schema:
{state["schema"]}

Instructions:
- Identify which tables are needed to answer the question
- Determine the join relationships between these tables
- Provide a concise schema path showing how tables should be connected
- Focus on the most relevant columns for joins and filtering

Provide a clear schema path with table relationships and key columns."""

    response = llm.invoke(prompt)
    
    # Store the schema path in the state
    state["schema_path"] = response.content
    
    return state

#### Node: Generator (5 Points)

**Task:** Generate the SQL query based on the question and schema.

* If a schema path is available, include it in the prompt.
* Save the output query in the `query` field of the state.


In [11]:
def query_generator_node(state: RouterGraphState):
    # Build the prompt based on whether schema path is available
    base_prompt = f"""Generate a syntactically correct SQL query to answer the following question based on the provided database schema.

Question: {state["question"]}

Database Schema:
{state["schema"]}"""

    # Add schema path information if available
    if state["schema_path"]:
        enhanced_prompt = f"""{base_prompt}

Schema Path Analysis:
{state["schema_path"]}

Use the schema path analysis to guide your query construction, ensuring proper joins and table relationships."""
    else:
        enhanced_prompt = base_prompt

    enhanced_prompt += """

Instructions:
- Generate only the SQL query without explanation
- Ensure the query is syntactically correct
- Use proper SQL syntax and conventions
- Format your response as: ```sql\n[query]```"""

    response = llm.invoke(enhanced_prompt)
    
    # Store the generated query in the state
    state["query"] = response.content
    
    return state

#### Build Graph (5 Points)

**Task:** Assemble the full routing graph using the nodes and edges you created.

In [12]:
from langgraph.graph import StateGraph, START, END

router_graph_builder = StateGraph(RouterGraphState)

# Add nodes
router_graph_builder.add_node("analyser", analyser_node)
router_graph_builder.add_node("schema_path_extractor", schema_path_extractor_node)
router_graph_builder.add_node("query_generator", query_generator_node)

# Add edges
router_graph_builder.add_edge(START, "analyser")
router_graph_builder.add_conditional_edges(
    "analyser",
    is_schema_extraction_needed,
    ["schema_path_extractor", "query_generator"]
)
router_graph_builder.add_edge("schema_path_extractor", "query_generator")
router_graph_builder.add_edge("query_generator", END)

router_graph = router_graph_builder.compile()

#### Run and Evaluate (Estimated Run Time 10-15min)

**Task:** Run your compiled routing graph on a dataset. For each question:

* Instantiate the `RouterGraphState` with the question and schema.
* Run the graph to completion.
* Extract and clean the query from the result.

Use the `run_method` function to handle iteration and timing.

In [14]:

from method_run import run_method
def run_router_graph(item):
    response = router_graph.invoke(
        RouterGraphState(
            question=item['question'],
            schema=item['schema'],
            schema_path=None,
            question_difficulty=None,
            query=None
        )
    )
    result = response["query"]
    # First try to extract query from markdown SQL block
    match = re.search(r'```sql\n(.*?)```', result, re.DOTALL)
    if match:
        query = match.group(1).strip()
    else:
        # If no markdown block found, try to extract just SQL query
        query = result.strip()
        # Remove any ```sql or ``` if present without proper formatting
        query = re.sub(r'```sql|```', '', query).strip()
    print(f"Question: {item['question']}")
    print(f"Schema: {item['schema']}")
    print(f"Question Difficulty: {response['question_difficulty']}")
    if response["schema_path"]:
        print(f"Schema Path: {response['schema_path']}")
    print(f"Generated SQL: {query}\n")
    return {**item, 'sql': query}


run_method(run_router_graph, SLEEP_TIME=30)

#Run on mode=nano if you want to test it on a smaller dataset
#run_method(run_router_graph, SLEEP_TIME=10, mode="nano")

File not found: data/dev_databases/student_club/database_description/event.csv
File not found: data/dev_databases/student_club/database_description/major.csv
File not found: data/dev_databases/student_club/database_description/zip_code.csv
File not found: data/dev_databases/student_club/database_description/attendance.csv
File not found: data/dev_databases/student_club/database_description/budget.csv
File not found: data/dev_databases/student_club/database_description/expense.csv
File not found: data/dev_databases/student_club/database_description/income.csv
File not found: data/dev_databases/student_club/database_description/member.csv


  0%|          | 0/18 [00:00<?, ?it/s]

Question: Find the percentage of atoms with single bond. (Evidence: single bond refers to bond_type = '-'; percentage = DIVIDE(SUM(bond_type = '-'), COUNT(bond_id)) as percentage)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Relevant Tables
- `bond`: Contains information about bonds including the bond_type, which is needed to identify single bonds (where bond_type = '-')

## Observations
- The question asks for the percentage of atoms with single bond
- We need to calculate: COUNT(single bonds) / COUNT(total bonds)
- The evidence indicates that single bonds are represented as bond_type = '-'
- Only the `bond` table is needed as we're calculating a percentage based solely on bond types

## Schema Path
`bond` → Filter where bond_type = '-' and calculate percentage

This is a simple query that only requires the `

  6%|▌         | 1/18 [00:41<11:45, 41.50s/it]

Question: Indicate which atoms are connected in non-carcinogenic type molecules. (Evidence: label = '-' means molecules are non-carcinogenic)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Relevant Tables
- `molecule`: Contains the carcinogenic label information
- `atom`: Contains atoms and their molecule associations
- `connected`: Shows which atoms are connected to each other
- `bond`: Contains bond information (though not essential for the core question)

## Schema Path
```
molecule (molecule_id, label)
    ↑
    | molecule_id = molecule_id
    ↓
atom (atom_id, molecule_id, element)
    ↑
    | atom_id = atom_id/atom_id2
    ↓
connected (atom_id, atom_id2, bond_id)
```

## Explanation
1. Start with the `molecule` table to filter for non-carcinogenic molecules (where `label = '-'`)
2. Join with the `atom` tabl

 11%|█         | 2/18 [01:24<11:19, 42.44s/it]

Question: What is the average number of bonds the atoms with the element iodine have? (Evidence: atoms with the element iodine refers to element = 'i'; average = DIVIDE(COUND(bond_id), COUNT(atom_id)) where element = 'i')
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Tables Needed
- **atom**: Contains element information to filter for iodine atoms (element = 'i')
- **connected**: Links atoms to bonds, needed to count bonds per atom
- **bond**: Not directly needed for this query, but connected via the connected table

## Join Relationships
- **atom** joins with **connected** on atom_id to find bonds associated with iodine atoms

## Schema Path
```
atom [atom_id, element='i'] → connected [atom_id, bond_id]
```

## Reasoning
To calculate the average number of bonds for iodine atoms, we need to:
1. Identify atoms w

 17%|█▋        | 3/18 [02:08<10:44, 42.96s/it]

Question: List down two molecule id of triple bond non carcinogenic molecules with element carbon. (Evidence: carbon refers to element = 'c'; triple bond refers to bond_type = '#'; label = '-' means molecules are non-carcinogenic)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Relevant Tables and Columns
- **molecule**: Contains molecule_id and label (where label = '-' for non-carcinogenic)
- **atom**: Contains molecule_id and element (to filter for carbon elements where element = 'c')
- **bond**: Contains molecule_id and bond_type (to filter for triple bonds where bond_type = '#')

## Schema Path
```
molecule (molecule_id, label = '-')
  |
  +--> atom (molecule_id, element = 'c') 
  |
  +--> bond (molecule_id, bond_type = '#')
```

## Join Relationships
The query requires molecules that:
1. Are non-carcinogenic

 22%|██▏       | 4/18 [02:51<10:06, 43.30s/it]

Question: What are the elements of the toxicology and label of molecule TR060? (Evidence: TR060 is the molecule id; label = '+' mean molecules are carcinogenic; label = '-' means molecules are non-carcinogenic; element = 'cl' means Chlorine; element = 'c' means Carbon; element = 'h' means Hydrogen; element = 'o' means Oxygen, element = 's' means Sulfur; element = 'n' means Nitrogen, element = 'p' means Phosphorus, element = 'na' means Sodium, element = 'br' means Bromine, element = 'f' means Fluorine; element = 'i' means Iodine; element = 'sn' means Tin; element = 'pb' means Lead; element = 'te' means Tellurium; element = 'ca' means Calcium)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Question Difficulty: simple
Generated SQL: SELECT a.element, m.label
FROM atom a
JOIN molecule m ON a.molecule_id = m.molecule_id
WHERE a.molecule_id = 'TR060'



 28%|██▊       | 5/18 [03:29<08:55, 41.19s/it]

Question: What are the elements for bond id TR001_10_11? (Evidence: element = 'cl' means Chlorine; element = 'c' means Carbon; element = 'h' means Hydrogen; element = 'o' means Oxygen, element = 's' means Sulfur; element = 'n' means Nitrogen, element = 'p' means Phosphorus, element = 'na' means Sodium, element = 'br' means Bromine, element = 'f' means Fluorine; element = 'i' means Iodine; element = 'sn' means Tin; element = 'pb' means Lead; element = 'te' means Tellurium; element = 'ca' means Calcium)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)

Question Difficulty: moderate
Schema Path: # Schema Analysis for the Question

## Tables Needed
- `bond` - Contains the bond_id mentioned in the question (TR001_10_11)
- `connected` - Links bonds to atoms
- `atom` - Contains element information

## Relevant Columns
- `bond.bond_id` - To filter for the specific bond ID "TR001_10_11"
- `co

 33%|███▎      | 6/18 [04:13<08:26, 42.18s/it]

Question: How many superheroes were published by Dark Horse Comics? (Evidence: published by Dark Horse Comics refers to publisher_name = 'Dark Horse Comics';)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Relevant Tables
- `superhero` - Contains the main superhero data
- `publisher` - Contains publisher information including 'Dark Horse Comics'

## Key Relationships
- `superhero.publisher_id` → `publisher.id` - Links superheroes to their publishers

## Schema Path
```
superhero JOIN publisher ON superhero.publisher_id = publisher.id
```

## Fi

 39%|███▉      | 7/18 [04:53<07:35, 41.41s/it]

Question: What are the race and alignment of Cameron Hicks? (Evidence: Cameron Hicks refers to superhero_name = 'Cameron Hicks';)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Tables Needed
- `superhero`: Contains the superhero named Cameron Hicks and references to race and alignment
- `race`: Contains race information
- `alignment`: Contains alignment information

## Relevant Columns
- `superhero.superhero_name`: To filter for "Cameron Hicks"
- `superhero.race_id`: Foreign key to race table
- `superhero.alignment_id`: Foreign key to alignment

 44%|████▍     | 8/18 [05:35<06:57, 41.71s/it]

Question: Among the superheroes with height from 170 to 190, list the names of the superheroes with no eye color. (Evidence: height from 170 to 190 refers to height_cm BETWEEN 170 AND 190; no eye color refers to eye_colour_id = 1)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Question Difficulty: simple
Generated SQL: SELECT superhero_name
FROM superhero
WHERE height_cm BETWEEN 170 AND 190
AND eye_colour_id = 1



 50%|█████     | 9/18 [06:12<06:02, 40.25s/it]

Question: List down at least five superpowers of male superheroes. (Evidence: male refers to gender = 'Male'; superpowers refers to power_name;)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Relevant Tables
- `superhero`: Contains basic hero information including gender_id
- `gender`: Contains gender values needed to filter for 'Male'
- `hero_power`: Junction table linking heroes to their powers
- `superpower`: Contains the power_name we need to display

## Schema Path
```
superhero JOIN gender ON superhero.gender_id = gender.id
    JOIN hero_

 56%|█████▌    | 10/18 [06:55<05:28, 41.03s/it]

Question: What is the percentage of superheroes who act in their own self-interest or make decisions based on their own moral code? Indicate how many of the said superheroes were published by Marvel Comics. (Evidence: published by Marvel Comics refers to publisher_name = 'Marvel Comics'; superheroes who act in their own self-interest or make decisions based on their own moral code refers to alignment = 'Bad'; calculation = MULTIPLY(DIVIDE(SUM(alignment = 'Bad); count(id)), 100))
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (id, power_name)
hero_power (hero_id, power_id)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

Based on the question 

 61%|██████    | 11/18 [07:40<04:55, 42.20s/it]

Question: Which publisher created more superheroes: DC or Marvel Comics? Find the difference in the number of superheroes. (Evidence: DC refers to publisher_name = 'DC Comics'; Marvel Comics refers to publisher_name = 'Marvel Comics'; if SUM(publisher_name = 'DC Comics') > SUM(publisher_name = 'Marvel Comics'), it means DC Comics published more superheroes than Marvel Comics; if SUM(publisher_name = 'Marvel Comics') > SUM(publisher_name = 'Marvel Comics'), it means Marvel Comics published more heroes than DC Comics; difference = SUBTRACT(SUM(publisher_name = 'DC Comics'), SUM(publisher_name = 'Marvel Comics'));)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (i

 67%|██████▋   | 12/18 [08:32<04:30, 45.10s/it]

Question: Who was the first one paid his/her dues? Tell the full name. (Evidence: full name refers to first_name, last_name; first paid dues refers to MIN(received_date) where source = 'Dues')
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Relevant Tables
- `member`: Contains personal information including first_name and last_name
- `income`: Contains payment records including da

 72%|███████▏  | 13/18 [09:15<03:43, 44.67s/it]

Question: How many income are received with an amount of 50? (Evidence: amount of 50 refers to amount = 50)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Question Difficulty: simple
Generated SQL: SELECT COUNT(*) 
FROM income 
WHERE amount = 50



 78%|███████▊  | 14/18 [09:50<02:47, 41.79s/it]

Question: Name the event with the highest amount spent on advertisement. (Evidence: event refers to event_name; highest amount spent on advertisement refers to MAX(spent) where category = 'Advertisement')
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Tables Needed
- `event`: Contains event_name which is required in the output
- `budget`: Contains spent amount and category for fi

 83%|████████▎ | 15/18 [10:35<02:07, 42.56s/it]

Question: Based on the total cost for all event, what is the percentage of cost for Yearly Kickoff event? (Evidence: DIVIDE(SUM(cost where event_name = 'Yearly Kickoff'), SUM(cost)) * 100)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Relevant Tables
The question asks for the percentage of cost for "Yearly Kickoff" event compared to the total cost for all events.

Based on the e

 89%|████████▉ | 16/18 [11:21<01:27, 43.66s/it]

Question: Calculate the total average cost that Elijah Allen spent in the events on September and October. (Evidence: events in September and October refers to month(expense_date) = 9 AND MONTH(expense_date) = 10)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Relevant Tables and Columns
- **expense**: Contains cost information and expense dates needed for calculating average cos

 94%|█████████▍| 17/18 [12:05<00:43, 43.74s/it]

Question: Find the name and date of events with expenses for pizza that were more than fifty dollars but less than a hundred dollars. (Evidence: name of event refers to event_name; date of event refers to event_date; expenses for pizza refers to expense_description = 'Pizza' where cost > 50 and cost < 100)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)

Question Difficulty: moderate
Schema Path: # Schema Path Analysis

## Relevant Tables and Columns

100%|██████████| 18/18 [12:48<00:00, 42.68s/it]

Starting to compare without knowledge for ex
Process finished successfully
start calculate
                     simple               moderate             challenging          total               
count                6                    6                    6                    18                  
accuracy             66.67                100.00               83.33                83.33               
Finished evaluation






### Agent (ReAct)

Now you will implement a full ReAct agent that incrementally solves the Text-to-SQL task using tools. The agent can explore tables and columns before finalizing the query.

**You are not allowed to use 'Prebuilt Agent' of LangGraph. You have to build your own graph.**

#### Define Tools

**Task:** Define three tools for the agent to interact with the schema:
1. `get_samples_from_table`: Returns the first few rows of a table.
2. `get_column_description`: Provides a human-readable description of a specific column.
3. `execute`: Executes a SQL query.

In [36]:
from langchain_core.tools import tool
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import ToolNode

from db_manager import DBManager
db_manager = DBManager()

@tool
def get_samples_from_table(table_name: str, config: RunnableConfig):
  """Gets the first few rows (samples) from a specified table.

  Args:
    table_name: The name of the table from which to fetch samples.

  Returns:
    The first few rows from the specified table.
  """
  db_name = config["configurable"].get("database_name")
  result = db_manager.get_table_head(table_name, db_name=db_name)
  return result

@tool
def get_column_description(table_name: str, column_name: str, config: RunnableConfig):
  """Provides a description for a specific column within a given table.

  Args:
    table_name: The name of the table containing the column.
    column_name: The name of the column for which to get the description.

  Returns:
    A string containing the description of the specified column.
  """
  db_name = config["configurable"].get("database_name")
  result = db_manager.get_column_description(db_name, table_name, column_name)
  return result

@tool
def execute(query: str, config: RunnableConfig):
  """Executes a given SQL query against the database.

  Args:
    query: The SQL query string to be executed.

  Returns:
    The result of the executed query. This could be a set of rows,
    a confirmation message, or an error.
  """
  db_name = config["configurable"].get("database_name")
  result = db_manager.query(query, db_name)
  return result

File not found: data/dev_databases/student_club/database_description/event.csv
File not found: data/dev_databases/student_club/database_description/major.csv
File not found: data/dev_databases/student_club/database_description/zip_code.csv
File not found: data/dev_databases/student_club/database_description/attendance.csv
File not found: data/dev_databases/student_club/database_description/budget.csv
File not found: data/dev_databases/student_club/database_description/expense.csv
File not found: data/dev_databases/student_club/database_description/income.csv
File not found: data/dev_databases/student_club/database_description/member.csv


#### Extra Tool (5+5 Bonus Points):

**Task**: Create and integrate a new custom tool into the ReAct agent. To receive credit for this part, your tool must be meaningfully different from the existing three tools and provide practical value in helping the agent generate more accurate or efficient SQL queries.

In [None]:
#YOUR CODE HERE

#### Create Tool Node

In [37]:
tools = [get_samples_from_table, get_column_description, execute]
tools_node = ToolNode(tools=tools)

#### ReAct Agent Prompt (5 Points)

**Task:** Set up the agent node with planning, tool use, and final SQL generation prompts. For writing efficient prompt you can read this link.
https://cookbook.openai.com/examples/gpt4-1_prompting_guide

In [38]:
REACT_SYS_PROMPT = """
You are an expert SQL query generation agent with access to database exploration tools. Your task is to generate accurate SQL queries by systematically analyzing the database schema and understanding the relationships between tables.

## Your Process:
1. **Understand the Question**: Carefully analyze what information is being requested
2. **Explore the Schema**: Use available tools to understand table structures, relationships, and sample data
3. **Plan Your Query**: Think through the necessary joins, filters, and aggregations
4. **Generate SQL**: Create a syntactically correct and efficient SQL query
5. **Validate**: If possible, test your query to ensure it works correctly

## Available Tools:
- `get_samples_from_table(table_name)`: Get sample rows from a table to understand data structure
- `get_column_description(table_name, column_name)`: Get detailed description of a specific column
- `execute(query)`: Execute a SQL query to test it or get results

## Guidelines:
- Always explore the schema first before writing queries
- Use sample data to understand data types and formats
- Consider relationships between tables when planning joins
- Generate clean, readable SQL with proper formatting
- Test your queries when possible to ensure correctness
- Provide only the final SQL query in ```sql``` format

## Response Format:
When you have your final query, format it as:
```sql
[your query here]
```

Begin by exploring the database schema to understand the structure before generating your query.
"""

#### Agent Node (5 Points)

**Task:** Set up the agent node with models that have binded with tools.

In [39]:
import time
from langchain_core.messages import SystemMessage, HumanMessage

def agent_node(state: MessagesState) -> MessagesState:
    time.sleep(10)

    # Bind tools
    agent_llm = llm.bind_tools(tools)

    # Always start with exactly one system message + all prior messages
    # (and drop any stray system messages in the middle)
    user_and_tool_messages = [
        m for m in state["messages"]
        if not isinstance(m, SystemMessage)
    ]
    messages = [SystemMessage(content=REACT_SYS_PROMPT)] + user_and_tool_messages

    # Invoke the LLM with tools
    response = agent_llm.invoke(messages)

    return {"messages": messages + [response]}


#### Build Graph (5 Points)

**Task:** Assemble the ReAct agent graph, connecting the agent node and tool node.

In [40]:
from langgraph.prebuilt import tools_condition
from typing_extensions import TypedDict

class ConfigSchema(TypedDict):
    database_name: str

react_builder = StateGraph(MessagesState, config_schema=ConfigSchema)

# Add nodes
react_builder.add_node("agent", agent_node)
react_builder.add_node("tools", tools_node)

# Add edges
react_builder.add_edge(START, "agent")
react_builder.add_conditional_edges(
    "agent",
    tools_condition,
    ["tools", END]
)
react_builder.add_edge("tools", "agent")

react_graph = react_builder.compile()

#### Run and Evaluate (Estimated Run Time 20min)

**Task:** Execute the ReAct agent pipeline on the dataset and collect SQL outputs.

In [41]:
from method_run import run_method
import re
def run_react_agent_with_config(item):
    question = item['question']
    schema = item['schema']
    user_prompt = f"Question: {question}\nSchema: {schema}"
    input_msg = HumanMessage(content=user_prompt)
    input_config = {"configurable": {"database_name": item['db_id']}}
    response = react_graph.invoke(MessagesState(messages=[input_msg]), config=input_config)

    for msg in response["messages"]:
        msg.pretty_print()
        
    # If last AI Message is a list of messages, we need to extract the last one
    last_msg = response["messages"][-1].content
    if isinstance(last_msg, list):
        last_msg = last_msg[-1]

    # First try to extract query from markdown SQL block
    match = re.search(r'```sql\n(.*?)```', last_msg, re.DOTALL)
    if match:
        query = match.group(1).strip()
    else:
        # If no markdown block found, try to extract just SQL query
        query = last_msg.strip()
        # Remove any ```sql or ``` if present without proper formatting
        query = re.sub(r'```sql|```', '', query).strip()

    return {**item, 'sql': query}

#Run agent on mode=nano, it's not needed to run on full dataset
run_method(run_react_agent_with_config, SLEEP_TIME=20, mode="nano")

File not found: data/dev_databases/student_club/database_description/event.csv
File not found: data/dev_databases/student_club/database_description/major.csv
File not found: data/dev_databases/student_club/database_description/zip_code.csv
File not found: data/dev_databases/student_club/database_description/attendance.csv
File not found: data/dev_databases/student_club/database_description/budget.csv
File not found: data/dev_databases/student_club/database_description/expense.csv
File not found: data/dev_databases/student_club/database_description/income.csv
File not found: data/dev_databases/student_club/database_description/member.csv


  0%|          | 0/5 [00:00<?, ?it/s]


Question: Find the percentage of atoms with single bond. (Evidence: single bond refers to bond_type = '-'; percentage = DIVIDE(SUM(bond_type = '-'), COUNT(bond_id)) as percentage)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)



You are an expert SQL query generation agent with access to database exploration tools. Your task is to generate accurate SQL queries by systematically analyzing the database schema and understanding the relationships between tables.

## Your Process:
1. **Understand the Question**: Carefully analyze what information is being requested
2. **Explore the Schema**: Use available tools to understand table structures, relationships, and sample data
3. **Plan Your Query**: Think through the necessary joins, filters, and aggregations
4. **Generate SQL**: Create a syntactically correct and efficient SQL query
5. **Validate**: If possible, test your query to ensur

 20%|██        | 1/5 [02:05<08:20, 125.22s/it]


Question: What are the elements for bond id TR001_10_11? (Evidence: element = 'cl' means Chlorine; element = 'c' means Carbon; element = 'h' means Hydrogen; element = 'o' means Oxygen, element = 's' means Sulfur; element = 'n' means Nitrogen, element = 'p' means Phosphorus, element = 'na' means Sodium, element = 'br' means Bromine, element = 'f' means Fluorine; element = 'i' means Iodine; element = 'sn' means Tin; element = 'pb' means Lead; element = 'te' means Tellurium; element = 'ca' means Calcium)
Schema: atom (atom_id, molecule_id, element)
bond (bond_id, molecule_id, bond_type)
connected (atom_id, atom_id2, bond_id)
molecule (molecule_id, label)



You are an expert SQL query generation agent with access to database exploration tools. Your task is to generate accurate SQL queries by systematically analyzing the database schema and understanding the relationships between tables.

## Your Process:
1. **Understand the Question**: Carefully analyze what information is being requeste

 40%|████      | 2/5 [03:36<05:16, 105.40s/it]


Question: Which publisher created more superheroes: DC or Marvel Comics? Find the difference in the number of superheroes. (Evidence: DC refers to publisher_name = 'DC Comics'; Marvel Comics refers to publisher_name = 'Marvel Comics'; if SUM(publisher_name = 'DC Comics') > SUM(publisher_name = 'Marvel Comics'), it means DC Comics published more superheroes than Marvel Comics; if SUM(publisher_name = 'Marvel Comics') > SUM(publisher_name = 'Marvel Comics'), it means Marvel Comics published more heroes than DC Comics; difference = SUBTRACT(SUM(publisher_name = 'DC Comics'), SUM(publisher_name = 'Marvel Comics'));)
Schema: alignment (id, alignment)
attribute (id, attribute_name)
colour (id, colour)
gender (id, gender)
publisher (id, publisher_name)
race (id, race)
superhero (id, superhero_name, full_name, gender_id, eye_colour_id, hair_colour_id, skin_colour_id, race_id, publisher_id, alignment_id, height_cm, weight_kg)
hero_attribute (hero_id, attribute_id, attribute_value)
superpower (

 60%|██████    | 3/5 [05:12<03:21, 100.98s/it]


Question: Find the name and date of events with expenses for pizza that were more than fifty dollars but less than a hundred dollars. (Evidence: name of event refers to event_name; date of event refers to event_date; expenses for pizza refers to expense_description = 'Pizza' where cost > 50 and cost < 100)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)



You are an expert SQL query generation agent with access to database exploration tools. Your t

 80%|████████  | 4/5 [06:25<01:29, 89.87s/it] 


Question: Based on the total cost for all event, what is the percentage of cost for Yearly Kickoff event? (Evidence: DIVIDE(SUM(cost where event_name = 'Yearly Kickoff'), SUM(cost)) * 100)
Schema: event (event_id, event_name, event_date, type, notes, location, status)
major (major_id, major_name, department, college)
zip_code (zip_code, type, city, county, state, short_state)
attendance (link_to_event, link_to_member)
budget (budget_id, category, spent, remaining, amount, event_status, link_to_event)
expense (expense_id, expense_description, expense_date, cost, approved, link_to_member, link_to_budget)
income (income_id, date_received, amount, source, notes, link_to_member)
member (member_id, first_name, last_name, email, position, t_shirt_size, phone, zip, link_to_major)



You are an expert SQL query generation agent with access to database exploration tools. Your task is to generate accurate SQL queries by systematically analyzing the database schema and understanding the relations

100%|██████████| 5/5 [09:05<00:00, 109.12s/it]

Starting to compare without knowledge for ex
Process finished successfully
start calculate
                     simple               moderate             challenging          total               
count                1                    1                    3                    5                   
accuracy             0.00                 0.00                 0.00                 0.00                
Finished evaluation




