# Question

In [342]:
# Sample user question fed to the workflow. It is written in Persian and means
# roughly "How is the market doing?".
question = "وضعیت بازار چگونه است"

# Setup

## Imports

In [343]:
import operator
import os
import re
from typing import TypedDict, List, Annotated

import psycopg2
from IPython.display import Image, display
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, END, START
from psycopg2 import sql

## State

### Sample

In [344]:
# Define the state structure
class AgentState(TypedDict):
    """State schema used by the sample SQL-agent nodes in this notebook.

    Each node reads the keys it needs and returns a dict with the keys it
    wants to update. ``messages`` carries ``operator.add`` as annotation
    metadata, i.e. list values written to it are meant to be appended to the
    existing list instead of replacing it (NOTE(review): this is applied by
    the graph runtime -- confirm against the installed LangGraph version).
    """
    messages: Annotated[List, operator.add]  # Accumulates all messages
    user_query: str  # Original user query, exactly as typed
    detected_language: str  # Language name reported by the model for user_query
    corrected_query: str  # English translation first, later the refined question
    is_relevant: bool  # Whether query is relevant to Chinook DB
    sql_query: str  # Generated SQL query
    query_result: str  # Formatted result rows, or an "Error ..." message
    final_answer: str  # Final response to user

### Custom

In [345]:
class States(TypedDict):
    """State schema of the custom workflow, which only detects the question's language."""
    messages: Annotated[List, operator.add]  # Accumulates all messages
    question: str  # Original user query
    language: str  # Language name returned by the detect_language node

## Prompts

### Detect Language

In [346]:
# System prompt for the language-detection node. It asks the model to answer
# with nothing but the English name of the input's language and gives a few
# input → answer examples to pin the expected output format.
prompt_detect_language = """
Detect the language of the user's text. 
Respond with ONLY the language name in English. 
Do not include any other text, explanations, or punctuation.

Examples:
- User: "Hola cómo estás" → Spanish
- User: "Bonjour comment ça va" → French
- User: "Hello how are you" → English
- User: "Wie geht es dir" → German
- User: "你好吗" → Chinese
- User: "お元気ですか" → Japanese
- User: "Как дела" → Russian
- User: "سلام، حالت چطوره" → Persian
"""

## Model

In [347]:
# Initialize the LLM shared by the LLM-backed nodes of this notebook.
# The temperature is kept low (0.1) because the nodes expect short, fixed-format answers.
# NOTE(review): `reasoning=False` presumably turns off the model's reasoning
# output -- confirm against the installed langchain_ollama version.
llm = ChatOllama(model="llama3.2:3b", reasoning=False, temperature=0.1)

## Database

### Connection

In [348]:
# Connection settings for the Chinook PostgreSQL database, passed verbatim to
# psycopg2.connect(). Every value can be overridden with an environment
# variable so that real credentials never have to be written into the
# notebook; the fallbacks are the original local-development settings.
DB_CONFIG = {
    "host": os.environ.get("CHINOOK_DB_HOST", "localhost"),
    "database": os.environ.get("CHINOOK_DB_NAME", "chinook"),
    "user": os.environ.get("CHINOOK_DB_USER", "postgres"),
    "password": os.environ.get("CHINOOK_DB_PASSWORD", "chinook"),
    "port": os.environ.get("CHINOOK_DB_PORT", "55000"),
}

# Database connection setup for PostgreSQL
def get_postgres_connection():
    """Open a new connection to the PostgreSQL database described by ``DB_CONFIG``.

    Returns:
        A new psycopg2 connection. The caller owns it and must close it.

    Raises:
        Exception: If the connection attempt fails. The message contains the
            original error text and the original exception is attached as
            ``__cause__``.
    """
    try:
        return psycopg2.connect(**DB_CONFIG)
    except Exception as e:
        # Chain explicitly so the traceback reports the driver error as the
        # direct cause instead of "during handling ... another exception".
        raise Exception(f"Failed to connect to PostgreSQL: {e}") from e

### Schema

In [349]:
def _format_table_schema(cursor, table_name, table_type):
    """Return the Markdown section for one table: columns, key markers and row count."""
    # Column metadata, in declaration order.
    cursor.execute("""
        SELECT column_name, data_type, is_nullable, column_default
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name = %s
        ORDER BY ordinal_position;
    """, (table_name,))
    columns = cursor.fetchall()

    # Names of the columns that belong to the primary key.
    cursor.execute("""
        SELECT kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        WHERE tc.constraint_type = 'PRIMARY KEY'
            AND tc.table_schema = 'public'
            AND tc.table_name = %s;
    """, (table_name,))
    primary_keys = {row[0] for row in cursor.fetchall()}

    # Foreign keys as (column, referenced table, referenced column).
    cursor.execute("""
        SELECT
            kcu.column_name,
            ccu.table_name AS foreign_table_name,
            ccu.column_name AS foreign_column_name
        FROM information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
        JOIN information_schema.constraint_column_usage AS ccu
            ON ccu.constraint_name = tc.constraint_name
            AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
            AND tc.table_schema = 'public'
            AND tc.table_name = %s;
    """, (table_name,))
    # When a column is listed more than once, the last row wins.
    fk_targets = {
        column: (ref_table, ref_column)
        for column, ref_table, ref_column in cursor.fetchall()
    }

    lines = [f"## Table: {table_name} ({table_type})\n", "### Columns:\n"]
    for column_name, data_type, is_nullable, column_default in columns:
        nullable_info = " NOT NULL" if is_nullable == 'NO' else ""
        pk_indicator = " (PK)" if column_name in primary_keys else ""
        fk_info = ""
        if column_name in fk_targets:
            ref_table, ref_column = fk_targets[column_name]
            fk_info = f" → {ref_table}({ref_column})"
        default_info = f" DEFAULT: {column_default}" if column_default else ""
        lines.append(
            f"  - {column_name}: {data_type}{nullable_info}{pk_indicator}{fk_info}{default_info}\n"
        )

    # sql.Identifier quotes the table name safely; it cannot be passed as a
    # regular query parameter.
    cursor.execute(sql.SQL("SELECT COUNT(*) FROM {}").format(sql.Identifier(table_name)))
    count = cursor.fetchone()[0]
    lines.append(f"  - Sample data: {count} rows\n\n")
    return "".join(lines)


def _format_view_schemas(cursor):
    """Return the Markdown section listing the views of ``public`` ('' when there are none)."""
    cursor.execute("""
        SELECT table_name, view_definition
        FROM information_schema.views
        WHERE table_schema = 'public'
        ORDER BY table_name;
    """)
    views = cursor.fetchall()
    if not views:
        return ""

    lines = ["## Views:\n"]
    for view_name, view_definition in views:
        # view_definition is NULL when the current role does not own the view;
        # treat that as an empty definition instead of failing on None[:200].
        snippet = (view_definition or "")[:200]
        lines.append(f"### View: {view_name}\n")
        lines.append(f"Definition: {snippet}...\n\n")
    return "".join(lines)


def get_database_schema():
    """Get the schema of the Chinook database (PostgreSQL) as Markdown text.

    Lists every table of the ``public`` schema with its columns (type,
    NOT NULL, primary-key and foreign-key markers, default) and its row
    count, followed by the first 200 characters of each view definition.

    Returns:
        str: The Markdown document, starting with "# Chinook Database Schema".

    Raises:
        Exception: If the connection cannot be opened, or if any of the
            introspection queries fails. The connection is closed in both cases.
    """
    conn = get_postgres_connection()
    try:
        cursor = conn.cursor()

        cursor.execute("""
            SELECT table_name, table_type
            FROM information_schema.tables
            WHERE table_schema = 'public'
            ORDER BY table_name;
        """)
        tables = cursor.fetchall()

        sections = ["# Chinook Database Schema\n\n"]
        for table_name, table_type in tables:
            sections.append(_format_table_schema(cursor, table_name, table_type))
        sections.append(_format_view_schemas(cursor))
        return "".join(sections)
    finally:
        # Always release the connection, even when a query above raised.
        conn.close()

# Nodes

## Detect Language

In [350]:
def detect_language(state: States) -> States:
    """Detect the language of the question.

    Sends ``state["question"]`` to the model together with the
    language-detection system prompt and stores the reply as the language name.

    Args:
        state: Current workflow state; only ``question`` is read.

    Returns:
        A state update with ``language`` set to the model's answer (surrounding
        whitespace removed), ``question`` passed through unchanged and an empty
        ``messages`` list, so nothing is appended to the message history.
    """

    print("⚡️ Detecting Language")

    detect_prompt = [
        SystemMessage(content=prompt_detect_language),
        HumanMessage(content=state["question"])
    ]

    # Models frequently pad short answers with spaces or newlines; strip them
    # so later steps see the bare language name.
    language = llm.invoke(detect_prompt).content.strip()

    print(f"📡 Language Detected:\n{language}")

    return {
        "messages": [],
        "language": language,
        "question": state["question"]
    }
    
        

## Other Nodes

In [351]:
# Node 1: Detect and translate to English
def detect_and_translate(state: AgentState):
    """Detect the language of the user query and translate it to English if needed.

    Args:
        state: Current agent state; only ``user_query`` is read.

    Returns:
        A state update with ``detected_language`` (the model's answer without
        surrounding whitespace, quotes or periods), ``corrected_query`` (the
        English text) and one message recording the translation.
    """
    print("Detecting language and translating to English...")

    # Detect language
    detect_prompt = [
        SystemMessage(content="Detect the language of this text. Respond with just the language name (e.g., 'English', 'Spanish', 'French')."),
        HumanMessage(content=state["user_query"])
    ]

    # The model may answer "English.", "'English'" or add a newline; normalise
    # the name so an English query is not sent through a pointless translation.
    detected_lang = llm.invoke(detect_prompt).content.strip(" \t\r\n.'\"")

    # Translate to English if not already
    if detected_lang.lower() != "english":
        translate_prompt = [
            SystemMessage(content=f"Translate this {detected_lang} text to English accurately."),
            HumanMessage(content=state["user_query"])
        ]
        translated_query = llm.invoke(translate_prompt).content
    else:
        translated_query = state["user_query"]

    return {
        "detected_language": detected_lang,
        "corrected_query": translated_query,
        "messages": [HumanMessage(content=f"Translated to English: {translated_query}")]
    }

In [352]:
# Node 2: Correct and improve the query
def correct_query(state: AgentState):
    """Ask the model to rewrite the question so it is precise enough for SQL generation.

    The current database schema is embedded in the system prompt and the
    question is taken from ``state["corrected_query"]``.

    Returns:
        A state update that replaces ``corrected_query`` with the model's rewrite.
    """
    print("Correcting and improving the query...")

    schema = get_database_schema()

    system_message = SystemMessage(content=f"""You are a database expert. Based on the Chinook database schema below, 
        correct and improve the user's question to make it more precise and suitable for SQL querying.
        
        Database Schema:
        {schema}
        
        Respond with only the improved question, no additional text.""")
    user_message = HumanMessage(content=state["corrected_query"])

    reply = llm.invoke([system_message, user_message])
    return {"corrected_query": reply.content}

In [353]:
# Node 3: Check relevance to Chinook database
def check_relevance(state: AgentState):
    """Decide whether the question can be answered from the Chinook database.

    The model is asked to reply with 'RELEVANT' or 'IRRELEVANT' for
    ``state["corrected_query"]``.

    Returns:
        ``{"is_relevant": True}`` for a relevant question; otherwise
        ``is_relevant`` set to False together with an English apology stored
        in ``final_answer``.
    """
    print("Checking relevance to Chinook database...")

    # A short fixed summary is enough for this yes/no decision, so the full
    # schema is not fetched from the database here.
    schema_summary = "Chinook database contains music store data: artists, albums, tracks, customers, invoices, employees, etc."

    relevance_prompt = [
        SystemMessage(content=f"""Determine if this question is relevant to the Chinook database.
        Chinook Database Content: {schema_summary}
        
        Respond with only 'RELEVANT' or 'IRRELEVANT'. No other text."""),
        HumanMessage(content=state["corrected_query"])
    ]

    response = llm.invoke(relevance_prompt).content.strip().upper()
    # "IRRELEVANT" contains "RELEVANT" as a substring, so the negative answer
    # has to be ruled out explicitly; a bare substring test is always true.
    relevant = "RELEVANT" in response and "IRRELEVANT" not in response

    if not relevant:
        return {
            "is_relevant": False,
            "final_answer": "I'm sorry, but this question doesn't seem to be related to the Chinook music database. Please ask about music, customers, invoices, artists, or other related topics."
        }

    return {"is_relevant": True}

In [354]:
# Node 4: Generate SQL query (updated for PostgreSQL)
def generate_sql(state: AgentState):
    """Generate a PostgreSQL query for ``state["corrected_query"]``.

    The current database schema is embedded in the system prompt. Markdown
    code fences around the model's answer are removed before it is returned.

    Returns:
        A state update with ``sql_query`` set to the cleaned SQL text.
    """
    print("Generating SQL query...")

    schema = get_database_schema()  # or get_database_schema_simple()

    sql_prompt = [
        SystemMessage(content=f"""You are a SQL expert. Generate a PostgreSQL query based on the Chinook database schema.
        
        Database Schema:
        {schema}
        
        Important: 
        - Use PostgreSQL syntax (e.g., ILIKE instead of LIKE for case-insensitive search)
        - Use double quotes for identifiers if needed
        - Only respond with the SQL query, no additional text or explanations.
        - Use proper JOIN syntax and avoid deprecated methods."""),
        HumanMessage(content=state["corrected_query"])
    ]

    sql_query = llm.invoke(sql_prompt).content
    # Remove markdown fences. The opening fence may carry a language tag in any
    # letter case ("```sql", "```SQL", "```postgresql", ...); without handling
    # those, the tag would be left at the start of the query and break it.
    fence_pattern = r'```(?:(?:postgresql|postgres|pgsql|psql|sql)\b)?[ \t]*\n?|\n?```'
    sql_query = re.sub(fence_pattern, '', sql_query, flags=re.IGNORECASE).strip()

    return {"sql_query": sql_query}

In [355]:
# Node 5: Execute SQL query (updated for PostgreSQL)
def execute_sql(state: AgentState):
    """Run ``state["sql_query"]`` against PostgreSQL and format the outcome as text.

    Returns:
        A state update with ``query_result`` set to one of:
        - the column names, one line per row and the total row count,
        - "No results found." when the query returns no rows,
        - a message starting with "Error executing query:" (followed by the
          SQL text) when connecting or executing fails.
    """
    print("Executing SQL query...")

    conn = None
    try:
        conn = get_postgres_connection()
        # The SQL text is produced by the model and therefore untrusted: run it
        # in a read-only session so it cannot modify or drop data.
        conn.set_session(readonly=True)
        cursor = conn.cursor()

        # Execute the query
        cursor.execute(state["sql_query"])
        results = cursor.fetchall()

        # Format results
        if results:
            column_names = [desc[0] for desc in cursor.description]
            lines = [f"Columns: {', '.join(column_names)}", "Results:"]
            lines.extend(str(row) for row in results)
            result_str = "\n".join(lines) + f"\n\nTotal rows: {len(results)}"
        else:
            result_str = "No results found."

        return {"query_result": result_str}

    except Exception as e:
        error_msg = f"Error executing query: {str(e)}"
        # Add the SQL query to error message for debugging
        error_msg += f"\nSQL Query: {state['sql_query']}"
        return {"query_result": error_msg}

    finally:
        # Close the connection on the error path as well; conn stays None when
        # the connection attempt itself failed.
        if conn is not None:
            conn.close()

In [356]:
# Node 6: Generate natural language response
def generate_response(state: AgentState):
    """Turn the SQL result into a natural-language answer in the user's language.

    Reads ``corrected_query``, ``query_result`` and ``detected_language``
    from the state. The answer is produced in English first and translated
    afterwards unless the detected language is English.

    Returns:
        A state update with ``final_answer`` set to the (possibly translated) answer.
    """
    print("Generating natural language response...")

    response_prompt = [
        SystemMessage(content="""Based on the SQL query results, provide a clear, natural language answer to the user's original question.
        Be concise and business-focused in your response."""),
        HumanMessage(content=f"""Original question: {state['corrected_query']}
        SQL Results: {state['query_result']}
        
        Provide a helpful response:""")
    ]

    english_response = llm.invoke(response_prompt).content

    # The language name comes from a model reply and may carry whitespace,
    # quotes or a trailing period ("English."); normalise it before comparing
    # so an English answer is not sent through a needless translation.
    target_language = state["detected_language"].strip(" \t\r\n.'\"")
    if target_language.lower() == "english":
        return {"final_answer": english_response}

    # Translate back to original language
    translate_prompt = [
        SystemMessage(content=f"Translate this English text to {target_language} accurately."),
        HumanMessage(content=english_response)
    ]
    return {"final_answer": llm.invoke(translate_prompt).content}

In [357]:
# Node 7: Handle irrelevant queries
def handle_irrelevant_query(state: AgentState):
    """Deliver the "not relevant" message in the user's language.

    Reads ``detected_language`` and the English message already stored in
    ``final_answer``.

    Returns:
        A state update containing only ``final_answer``: the translated
        message, or the unchanged English one when the user wrote in English.
    """
    print("Handling irrelevant query...")

    # The language name comes from a model reply and may carry whitespace,
    # quotes or a trailing period; normalise it before comparing.
    target_language = state["detected_language"].strip(" \t\r\n.'\"")

    if target_language.lower() != "english":
        # Translate the irrelevant message to user's language
        translate_prompt = [
            SystemMessage(content=f"Translate this English text to {target_language}:"),
            HumanMessage(content=state["final_answer"])
        ]
        translated_response = llm.invoke(translate_prompt).content
        return {"final_answer": translated_response}

    # Return only the key this node is responsible for. Returning the whole
    # state would re-emit `messages`, and because that key is combined with
    # operator.add the existing history would be appended to itself.
    return {"final_answer": state["final_answer"]}

In [358]:
# Conditional edge functions
def is_relevant(state: AgentState):
    """Routing predicate: return the ``is_relevant`` flag, or False when it is not set."""
    if "is_relevant" in state:
        return state["is_relevant"]
    return False

In [359]:
def sql_execution_successful(state: AgentState):
    """Routing predicate: True unless ``query_result`` starts with "Error".

    A missing ``query_result`` is treated as an empty string, i.e. as success.
    """
    outcome = state.get("query_result", "")
    if outcome.startswith("Error"):
        return False
    return True

# Workflow

## Sample

In [360]:
# Build the graph
def create_chinook_agent():
    """Compile the sample agent graph.

    Only the language-detection step is wired in at the moment
    (START -> detect_language -> END); the rest of the pipeline is kept
    below as commented-out code.

    Returns:
        The compiled graph built on the ``States`` schema.
    """
    node_name = detect_language.__name__

    graph = StateGraph(States)
    graph.add_node(node_name, detect_language)
    graph.add_edge(START, node_name)
    graph.add_edge(node_name, END)

    # TODO: planned full pipeline, disabled for now.
    # NOTE: the node functions referenced here are annotated with AgentState,
    # while this graph is built on States -- the schema has to be switched
    # before these lines are enabled.
    #
    # Add nodes
    # workflow.add_node("detect_translate", detect_and_translate)
    # workflow.add_node("correct_query", correct_query)
    # workflow.add_node("check_relevance", check_relevance)
    # workflow.add_node("generate_sql", generate_sql)
    # workflow.add_node("execute_sql", execute_sql)
    # workflow.add_node("generate_response", generate_response)
    # workflow.add_node("handle_irrelevant", handle_irrelevant_query)

    # # Set entry point
    # workflow.set_entry_point("detect_translate")

    # # Add edges
    # workflow.add_edge("detect_translate", "correct_query")
    # workflow.add_edge("correct_query", "check_relevance")

    # # Conditional edge after relevance check
    # workflow.add_conditional_edges(
    #     "check_relevance",
    #     is_relevant,
    #     {
    #         True: "generate_sql",
    #         False: "handle_irrelevant"
    #     }
    # )

    # workflow.add_edge("generate_sql", "execute_sql")
    # workflow.add_edge("execute_sql", "generate_response")
    # workflow.add_edge("generate_response", END)
    # workflow.add_edge("handle_irrelevant", END)

    return graph.compile()

## Custom

In [361]:
def create_workflow():
    """Compile the custom workflow: START -> detect_language -> END.

    Returns:
        The compiled graph built on the ``States`` schema.
    """
    step = detect_language.__name__

    builder = StateGraph(States)
    builder.add_node(step, detect_language)

    # The graph is a straight line, so the edges are listed in order.
    for source, target in ((START, step), (step, END)):
        builder.add_edge(source, target)

    return builder.compile()

# App

## Sample

In [362]:
# # Main function to test the agent

# # Create the agent
# agent = create_chinook_agent()

# initial_state = {
#     "messages": [],
#     "user_query": question,
#     "detected_language": "en",
#     "corrected_query": "",
#     "is_relevant": True,
#     "sql_query": "",
#     "query_result": "",
#     "final_answer": ""
# }

# agent.invoke(initial_state)

# try:
#     result = agent.invoke(initial_state)
#     print(f"\nFinal Answer: {result['final_answer']}")
#     print(f"\nDetected Language: {result.get('detected_language', 'Unknown')}")
#     print(f"Corrected Query: {result.get('corrected_query', 'N/A')}")
#     if result.get('sql_query'):
#         print(f"Generated SQL: {result['sql_query']}")
# except Exception as e:
#     print(f"Error: {e}")

## Custom

In [363]:

# Compile the custom (language-detection only) workflow.
agent = create_workflow()

# Initial state: an empty message history plus the question defined at the top
# of the notebook. `language` is left out; the detect_language node fills it in.
initiate_state = {
    "messages": [],
    "question": question,
}

# Run the graph. As the last expression of the cell, the returned state is
# displayed as the cell output.
agent.invoke(initiate_state)

⚡️ Detecting Language


📡 Language Detected:
Persian


{'messages': [], 'question': 'وضعیت بازار چگونه است', 'language': 'Persian'}