<a href="https://colab.research.google.com/github/Aradhyakapil/ReAct-Agent-with-CoT-for-Text-to-SQL/blob/main/ReAct_Agent_with_CoT_for_Text_to_SQL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# Install required packages (run this in Colab)

!pip install openai sqlalchemy pandas python-dotenv

Collecting python-dotenv
  Downloading python_dotenv-1.1.1-py3-none-any.whl.metadata (24 kB)
Downloading python_dotenv-1.1.1-py3-none-any.whl (20 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.1.1


In [None]:
# Complete ReAct Agent with Chain-of-Thought for Text-to-SQL
# Updated for modern OpenAI API (openai>=1.0.0)

import os
import sqlite3
import pandas as pd
from sqlalchemy import create_engine, Column, Integer, String, Float, Date, ForeignKey, text
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from datetime import datetime, timedelta
import random
from openai import OpenAI
from typing import List, Dict, Any
import json
import re


# Set your OpenAI API key here
# Prefer the OPENAI_API_KEY environment variable so the real secret never has
# to be written into the notebook; the literal is only a placeholder fallback.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "sk-")  # Replace the fallback with your actual API key if not using the environment

# Initialize OpenAI client (module-level; used by ReActAgent._call_openai)
client = OpenAI(api_key=OPENAI_API_KEY)

# Database setup: declarative base shared by all ORM models below
Base = declarative_base()

class Company(Base):
    """ORM model for the ``companies`` table.

    A company is referenced by employees, products and sales through their
    ``company_id`` foreign keys.
    """
    __tablename__ = 'companies'

    # Primary key.
    id = Column(Integer, primary_key=True)
    # Required company name.
    name = Column(String(100), nullable=False)
    industry = Column(String(50))
    founded_year = Column(Integer)
    headquarters = Column(String(100))
    revenue = Column(Float)  # in millions

    # Reverse sides of the ``company`` relationships declared on
    # Employee, Product and Sale.
    employees = relationship("Employee", back_populates="company")
    products = relationship("Product", back_populates="company")
    sales = relationship("Sale", back_populates="company")

class Employee(Base):
    """ORM model for the ``employees`` table; each row belongs to one company."""
    __tablename__ = 'employees'

    # Primary key.
    id = Column(Integer, primary_key=True)
    # Both name parts are required.
    first_name = Column(String(50), nullable=False)
    last_name = Column(String(50), nullable=False)
    email = Column(String(100))
    department = Column(String(50))
    position = Column(String(100))
    salary = Column(Float)
    hire_date = Column(Date)
    # Foreign key to ``companies.id``.
    company_id = Column(Integer, ForeignKey('companies.id'))

    # Paired with ``Company.employees``.
    company = relationship("Company", back_populates="employees")

class Product(Base):
    """ORM model for the ``products`` table; each row belongs to one company."""
    __tablename__ = 'products'

    # Primary key.
    id = Column(Integer, primary_key=True)
    # Required product name.
    name = Column(String(100), nullable=False)
    category = Column(String(50))
    price = Column(Float)
    launch_date = Column(Date)
    # Foreign key to ``companies.id``.
    company_id = Column(Integer, ForeignKey('companies.id'))

    # Paired with ``Company.products`` and ``Sale.product`` respectively.
    company = relationship("Company", back_populates="products")
    sales = relationship("Sale", back_populates="product")

class Sale(Base):
    """ORM model for the ``sales`` table; links a product and a company."""
    __tablename__ = 'sales'

    # Primary key.
    id = Column(Integer, primary_key=True)
    # Foreign keys to ``products.id`` and ``companies.id``.
    product_id = Column(Integer, ForeignKey('products.id'))
    company_id = Column(Integer, ForeignKey('companies.id'))
    quantity = Column(Integer)
    total_amount = Column(Float)
    sale_date = Column(Date)
    customer_segment = Column(String(50))

    # Paired with ``Product.sales`` and ``Company.sales`` respectively.
    product = relationship("Product", back_populates="sales")
    company = relationship("Company", back_populates="sales")

class DatabaseManager:
    """Owns the SQLAlchemy engine/session and the demo data set.

    Creates all tables declared on ``Base`` at construction time, can fill
    them with random demo rows, describes the schema for prompt building and
    executes raw SQL strings on behalf of the agent.
    """

    def __init__(self, db_url='sqlite:///company_data.db'):
        """Create the engine for *db_url*, create missing tables, open a session."""
        self.engine = create_engine(db_url, echo=False)
        Base.metadata.create_all(self.engine)
        Session = sessionmaker(bind=self.engine)
        self.session = Session()

    def populate_dummy_data(self):
        """Populate database with meaningful dummy company data

        Existing rows in all four tables are deleted first, so repeated calls
        always start from an empty database.
        """

        # Clear existing data (children before parents to respect foreign keys)
        self.session.query(Sale).delete()
        self.session.query(Product).delete()
        self.session.query(Employee).delete()
        self.session.query(Company).delete()
        self.session.commit()

        # Date columns hold calendar dates, so work with ``date`` objects
        # rather than full ``datetime`` values.
        today = datetime.now().date()

        # Companies data
        companies_data = [
            {"name": "TechCorp Solutions", "industry": "Technology", "founded_year": 2010, "headquarters": "San Francisco", "revenue": 250.5},
            {"name": "GreenEnergy Inc", "industry": "Renewable Energy", "founded_year": 2015, "headquarters": "Austin", "revenue": 180.2},
            {"name": "HealthCare Plus", "industry": "Healthcare", "founded_year": 2008, "headquarters": "Boston", "revenue": 320.8},
            {"name": "EduTech Innovations", "industry": "Education Technology", "founded_year": 2018, "headquarters": "Seattle", "revenue": 95.3},
            {"name": "FinanceFlow Ltd", "industry": "Financial Services", "founded_year": 2012, "headquarters": "New York", "revenue": 445.7}
        ]

        companies = [Company(**comp_data) for comp_data in companies_data]
        self.session.add_all(companies)

        # Commit so the companies get their primary keys before children refer to them
        self.session.commit()

        # Employee data
        departments = ['Engineering', 'Sales', 'Marketing', 'HR', 'Finance', 'Operations']
        positions = ['Software Engineer', 'Senior Developer', 'Sales Manager', 'Marketing Specialist',
                    'HR Coordinator', 'Financial Analyst', 'Operations Manager', 'Data Scientist']
        first_names = ['John', 'Jane', 'Michael', 'Sarah', 'David', 'Lisa', 'Robert', 'Emily']
        last_names = ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Garcia', 'Miller', 'Davis']

        for company in companies:
            # Email domain is the company name lower-cased with spaces removed
            email_domain = company.name.lower().replace(' ', '')
            for _ in range(random.randint(15, 30)):  # 15-30 employees per company
                employee = Employee(
                    first_name=random.choice(first_names),
                    last_name=random.choice(last_names),
                    email=f"employee{random.randint(1000, 9999)}@{email_domain}.com",
                    department=random.choice(departments),
                    position=random.choice(positions),
                    salary=random.uniform(50000, 200000),
                    hire_date=today - timedelta(days=random.randint(30, 1825)),
                    company_id=company.id
                )
                self.session.add(employee)

        # Products data
        product_categories = ['Software', 'Hardware', 'Service', 'Platform', 'Tool']

        for company in companies:
            # Product names are prefixed with the first word of the company name
            name_prefix = company.name.split()[0]
            for i in range(random.randint(3, 8)):  # 3-8 products per company
                product = Product(
                    name=f"{name_prefix} Product {i+1}",
                    category=random.choice(product_categories),
                    price=random.uniform(50, 5000),
                    launch_date=today - timedelta(days=random.randint(30, 1095)),
                    company_id=company.id
                )
                self.session.add(product)

        self.session.commit()

        # Sales data
        customer_segments = ['Enterprise', 'SMB', 'Consumer', 'Government']

        # Get all products for sales generation
        products = self.session.query(Product).all()

        for product in products:
            for _ in range(random.randint(5, 20)):  # 5-20 sales per product
                sale = Sale(
                    product_id=product.id,
                    company_id=product.company_id,
                    quantity=random.randint(1, 100),
                    total_amount=random.uniform(100, 50000),
                    sale_date=today - timedelta(days=random.randint(1, 365)),
                    customer_segment=random.choice(customer_segments)
                )
                self.session.add(sale)

        self.session.commit()
        print("Database populated with dummy data successfully!")

    def get_schema_info(self):
        """Get database schema information

        Returns:
            Dict mapping each table name to a dict with ``'columns'`` (list of
            column names) and ``'description'`` (free-text summary).
        """
        schema_info = {
            'companies': {
                'columns': ['id', 'name', 'industry', 'founded_year', 'headquarters', 'revenue'],
                'description': 'Information about companies including their industry, founding year, headquarters, and revenue'
            },
            'employees': {
                'columns': ['id', 'first_name', 'last_name', 'email', 'department', 'position', 'salary', 'hire_date', 'company_id'],
                'description': 'Employee information including personal details, job information, and company association'
            },
            'products': {
                'columns': ['id', 'name', 'category', 'price', 'launch_date', 'company_id'],
                'description': 'Product information including name, category, price, and company association'
            },
            'sales': {
                'columns': ['id', 'product_id', 'company_id', 'quantity', 'total_amount', 'sale_date', 'customer_segment'],
                'description': 'Sales transactions including product, quantity, amount, and customer segment'
            }
        }
        return schema_info

    def execute_query(self, query: str):
        """Execute SQL query and return results

        Args:
            query: Raw SQL text; it is passed to the database unmodified.

        Returns:
            ``{"columns": [...], "rows": [[...], ...]}`` on success, or
            ``{"error": "<message>"}`` if anything raised.  On error the
            session is rolled back.
        """
        try:
            result = self.session.execute(text(query))
            columns = result.keys()
            rows = result.fetchall()
            return {"columns": list(columns), "rows": [list(row) for row in rows]}
        except Exception as e:
            # The query text is not under this class's control.  Roll back so
            # that a failed statement - or one that executed but whose result
            # could not be read as rows - leaves no pending changes in the
            # shared session for a later commit to persist.
            self.session.rollback()
            return {"error": str(e)}

class ReActAgent:
    """ReAct-style Text-to-SQL agent backed by an OpenAI chat model.

    For each question the agent reasons about the schema, asks the model for
    a SQL query, executes it through the database manager and, on success,
    asks the model to summarise the rows.  Failed queries are fed back into
    the next reasoning step.
    """

    # Prefix of the string ``_call_openai`` returns when the API call fails;
    # ``process_query`` uses it to tell a failure apart from model output.
    _API_ERROR_PREFIX = "Error calling OpenAI API:"

    def __init__(self, db_manager: "DatabaseManager"):
        """Store the database manager and cache its schema description."""
        self.db_manager = db_manager
        self.schema_info = db_manager.get_schema_info()

    def _call_openai(self, messages: List[Dict[str, str]], temperature: float = 0.1) -> str:
        """Call OpenAI API with error handling using new client

        Returns the stripped reply text.  Any exception is converted into a
        string starting with ``_API_ERROR_PREFIX`` instead of being raised.
        """
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
                temperature=temperature,
                max_tokens=1000
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"{self._API_ERROR_PREFIX} {str(e)}"

    @staticmethod
    def _clean_sql(raw_sql: str) -> str:
        """Strip Markdown code fences from a model reply and trim whitespace.

        The ``sql`` language tag is matched case-insensitively so that a
        reply fenced with an upper-case ``SQL`` tag does not leave a stray
        token in front of the query.
        """
        return re.sub(r'```(?:sql)?[ \t]*\n?', '', raw_sql, flags=re.IGNORECASE).strip()

    def _generate_thought(self, question: str, previous_actions: List[str] = None) -> str:
        """Generate Chain-of-Thought reasoning

        Args:
            question: The user's natural-language question.
            previous_actions: Log lines from earlier iterations (generated SQL
                and errors); appended to the prompt when non-empty.
        """
        schema_desc = "\n".join([
            f"- {table}: {info['description']} (columns: {', '.join(info['columns'])})"
            for table, info in self.schema_info.items()
        ])

        previous_context = ""
        if previous_actions:
            previous_context = f"\nPrevious actions taken:\n" + "\n".join(previous_actions)

        messages = [
            {
                "role": "system",
                "content": f"""You are a SQL expert with access to a company database. Think step by step about how to answer the user's question.

Database Schema:
{schema_desc}

Your task is to reason through the question step by step, considering:
1. What tables are needed
2. What columns are required
3. What joins are necessary
4. What filters or conditions are needed
5. What aggregations might be required

Respond with your thinking process in a clear, step-by-step manner."""
            },
            {
                "role": "user",
                "content": f"Question: {question}{previous_context}\n\nThink through this step by step:"
            }
        ]

        return self._call_openai(messages)

    def _generate_sql(self, question: str, thought: str) -> str:
        """Generate SQL query based on question and reasoning

        Returns the raw model reply, which may still contain Markdown fences.
        """
        schema_desc = "\n".join([
            f"Table: {table}\nColumns: {', '.join(info['columns'])}\nDescription: {info['description']}\n"
            for table, info in self.schema_info.items()
        ])

        messages = [
            {
                "role": "system",
                "content": f"""You are a SQL query generator. Based on the reasoning provided, generate a precise SQL query.

Database Schema:
{schema_desc}

Rules:
- Use proper JOIN syntax for multi-table queries
- Use appropriate WHERE clauses for filtering
- Use GROUP BY and aggregate functions when needed
- Ensure column names match the schema exactly
- Return only the SQL query, no explanations"""
            },
            {
                "role": "user",
                "content": f"Question: {question}\n\nReasoning: {thought}\n\nGenerate the SQL query:"
            }
        ]

        return self._call_openai(messages)

    def _analyze_results(self, question: str, query: str, results: Dict[str, Any]) -> str:
        """Analyze query results and provide natural language response

        Only the first five rows are included in the prompt.  If *results*
        carries an ``"error"`` key the error is returned without calling the
        model.
        """
        if "error" in results:
            return f"Query execution failed: {results['error']}"

        # Format results for analysis
        if results["rows"]:
            lines = [
                f"Query returned {len(results['rows'])} rows.\n",
                f"Columns: {', '.join(results['columns'])}\n",
                "Sample data:\n",
            ]
            # Show first 5 rows
            lines.extend(
                f"Row {i+1}: {dict(zip(results['columns'], row))}\n"
                for i, row in enumerate(results["rows"][:5])
            )
            results_text = "".join(lines)
        else:
            results_text = "Query returned no results."

        messages = [
            {
                "role": "system",
                "content": "You are a data analyst. Analyze the SQL query results and provide a clear, natural language answer to the user's question. Be specific and include relevant numbers from the data."
            },
            {
                "role": "user",
                "content": f"Original question: {question}\nSQL query used: {query}\nResults: {results_text}\n\nProvide a clear answer:"
            }
        ]

        return self._call_openai(messages)

    def process_query(self, question: str, max_iterations: int = 3) -> Dict[str, Any]:
        """Process natural language question using ReAct methodology

        Args:
            question: The user's natural-language question.
            max_iterations: Maximum number of think/act/observe attempts.

        Returns:
            On success a dict with ``question``, ``thought_process``,
            ``sql_query``, ``results``, ``analysis``, ``iterations`` and
            ``success=True``.  Otherwise a dict with ``success=False`` and an
            ``error`` describing the last failed attempt (or the generic
            "maximum iterations" message when no attempt was made).
        """
        print(f"\n🤔 Processing question: {question}")

        actions_taken = []

        # Returned if the loop ends without success; overwritten with the
        # details of the most recent failed attempt.
        failure = {
            "question": question,
            "error": "Maximum iterations reached without successful query execution",
            "iterations": max_iterations,
            "success": False
        }

        for iteration in range(max_iterations):
            print(f"\n--- Iteration {iteration + 1} ---")

            # THINK: Generate reasoning
            print("💭 Thinking...")
            thought = self._generate_thought(question, actions_taken)
            print(f"Thought: {thought}")

            # ACT: Generate SQL query
            print("⚡ Acting...")
            sql_query = self._generate_sql(question, thought)
            print(f"Generated SQL: {sql_query}")

            # An API failure comes back as an error string, not SQL.  Do not
            # send it to the database (that would only mask the real problem
            # behind a syntax error); report it and try again.
            if sql_query.startswith(self._API_ERROR_PREFIX):
                print(f"❌ Query failed: {sql_query}")
                failure = {
                    "question": question,
                    "thought_process": thought,
                    "sql_query": "",
                    "error": sql_query,
                    "iterations": iteration + 1,
                    "success": False
                }
                continue

            # Clean up SQL query (remove markdown formatting if present)
            sql_query = self._clean_sql(sql_query)

            # OBSERVE: Execute query and get results
            print("👁️ Observing...")
            results = self.db_manager.execute_query(sql_query)

            actions_taken.append(f"Iteration {iteration + 1}: Generated SQL: {sql_query}")

            if "error" in results:
                print(f"❌ Query failed: {results['error']}")
                # Feed the error back into the next reasoning step
                actions_taken.append(f"Error: {results['error']}")
                failure = {
                    "question": question,
                    "thought_process": thought,
                    "sql_query": sql_query,
                    "error": results['error'],
                    "iterations": iteration + 1,
                    "success": False
                }
                continue

            print("✅ Query executed successfully!")

            # ANALYZE: Provide natural language response
            print("📊 Analyzing results...")
            analysis = self._analyze_results(question, sql_query, results)

            return {
                "question": question,
                "thought_process": thought,
                "sql_query": sql_query,
                "results": results,
                "analysis": analysis,
                "iterations": iteration + 1,
                "success": True
            }

        return failure

def demonstrate_agent():
    """Build a freshly populated database, run three sample questions, return the agent.

    For every question the outcome is printed: the analysis plus up to three
    result rows on success, or the error on failure.
    """
    print("🚀 Initializing ReAct Text-to-SQL Agent...")

    # Fresh database + agent for the demo run
    database = DatabaseManager()
    database.populate_dummy_data()
    agent = ReActAgent(database)

    # Questions chosen to exercise joins, filters and aggregation
    sample_questions = (
        "Which companies in the technology industry have products launched after 2020?",
        "What is the average employee salary for companies with revenue over 200 million?",
        "Show me the top 5 products by total sales amount along with their company names",
    )

    separator = "\n" + "="*80
    print(separator)
    print("🎯 REACT AGENT DEMONSTRATION")
    print("="*80)

    for number, question in enumerate(sample_questions, start=1):
        print(f"\n📝 Question {number}: {question}")
        print("-" * 80)

        outcome = agent.process_query(question)

        if not outcome["success"]:
            print(f"\n❌ FAILED after {outcome['iterations']} iterations")
            if "error" in outcome:
                print(f"Error: {outcome['error']}")
        else:
            print(f"\n✅ SUCCESS (in {outcome['iterations']} iterations)")
            print(f"📊 Analysis: {outcome['analysis']}")

            # Preview at most three rows of the result set
            preview = outcome["results"]["rows"][:3]
            if preview:
                print(f"\n📋 Sample Results ({len(preview)} rows):")
                for position, row in enumerate(preview, start=1):
                    as_mapping = dict(zip(outcome["results"]["columns"], row))
                    print(f"  {position}. {as_mapping}")

        print(separator)

    return agent

def interactive_mode(agent: ReActAgent):
    """Prompt for questions in a loop and print the agent's answer to each.

    ``quit`` ends the loop, ``schema`` prints the table descriptions, and
    empty input is ignored; both commands are matched case-insensitively.
    """
    print("\n🎮 Interactive Mode - Ask your own questions!")
    print("Type 'quit' to exit, 'schema' to see database schema")

    while True:
        question = input("\n❓ Your question: ").strip()

        # Nothing typed: just prompt again
        if not question:
            continue

        command = question.lower()

        if command == 'quit':
            print("👋 Goodbye!")
            return

        if command == 'schema':
            print("\n📊 Database Schema:")
            for table, info in agent.schema_info.items():
                print(f"  {table}: {info['description']}")
                print(f"    Columns: {', '.join(info['columns'])}")
            continue

        # Anything else is treated as a question for the agent
        outcome = agent.process_query(question)

        if not outcome["success"]:
            print(f"\n❌ Failed to answer: {outcome.get('error', 'Unknown error')}")
        else:
            print(f"\n✅ Answer: {outcome['analysis']}")

# Main execution
if __name__ == "__main__":
    print("="*80)
    print("🤖 ReAct Agent for Text-to-SQL with Chain-of-Thought Reasoning")
    print("="*80)

    # Check if OpenAI API key is set.  An empty key and every known
    # placeholder (including the "sk-" stub used in this file) count as
    # "not configured", so the demo does not start with an unusable key.
    placeholder_keys = {"", "sk-", "your-openai-api-key-here"}
    if (OPENAI_API_KEY or "").strip() in placeholder_keys:
        print("⚠️  Please set your OpenAI API key in the OPENAI_API_KEY variable")
        print("You can get your API key from: https://platform.openai.com/api-keys")
    else:
        # Run demonstration
        agent = demonstrate_agent()

        # Optionally run interactive mode
        run_interactive = input("\n🎮 Would you like to try interactive mode? (y/n): ").strip().lower()
        if run_interactive == 'y':
            interactive_mode(agent)

🤖 ReAct Agent for Text-to-SQL with Chain-of-Thought Reasoning
🚀 Initializing ReAct Text-to-SQL Agent...
Database populated with dummy data successfully!

🎯 REACT AGENT DEMONSTRATION

📝 Question 1: Which companies in the technology industry have products launched after 2020?
--------------------------------------------------------------------------------

🤔 Processing question: Which companies in the technology industry have products launched after 2020?

--- Iteration 1 ---
💭 Thinking...
Thought: To answer the question "Which companies in the technology industry have products launched after 2020?", I will follow these steps:

### Step 1: Identify the Relevant Tables
The question involves information about companies and their products. Therefore, the relevant tables are:
- **companies**: To filter companies by industry.
- **products**: To filter products based on their launch date.

### Step 2: Identify Required Columns
From the identified tables, I need the following columns:
- From th

KeyboardInterrupt: 