In [5]:
import os
from sqlalchemy import create_engine, Column, Integer, String, Float, ForeignKey
from sqlalchemy.orm import relationship, declarative_base, sessionmaker

Base = declarative_base()

class User(Base):
    __tablename__ = "users"
    
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    age = Column(Integer)
    email = Column(Integer, unique=True, index=True)

    orders = relationship("Order", back_populates="user")
    
class Food(Base):
    __tablename__ = "food"
    
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    price = Column(Float)
    
    orders = relationship("Order", back_populates="food")
    
class Order(Base):
    __tablename__ = "orders"
    
    id = Column(Integer, primary_key=True, index=True)
    food_id = Column(Integer, ForeignKey("food.id"))
    user_id = Column(Integer, ForeignKey("users.id"))
    
    user = relationship("User", back_populates="orders")
    food = relationship("Food", back_populates="orders")
    
DATABASE_URL = "sqlite:///example.db"

engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

def init_db():
    Base.metadata.create_all(bind=engine)
    
    session = SessionLocal()
    
    users = [
        User(name="Akshay", age=26, email="akshay@gmail.com"),
        User(name="Abhi", age=20, email="Abhi@gmail.com"),
        User(name="Akhi", age=22, email="Akhi@gmail.com")
    ]
    
    session.add_all(users)
    session.commit()
    
    foods = [
        Food(name="Pizza", price=10.5),
        Food(name="Burger", price=7.9),
        Food(name="Pasta", price=8.5)
    ]
    
    session.add_all(foods)
    session.commit()
    
    orders = [
        Order(food_id=1, user_id=1),
        Order(food_id=2, user_id=2),
        Order(food_id=3, user_id=2)
    ]
    
    session.add_all(orders)
    session.commit()
    
    session.close()
    print("Database initialized and sample data added.")
    
if __name__ == "__main__":
    if not os.path.exists("example.db"):
        init_db()
    else:
        print("Database already exists. No changes made.")

Database already exists. No changes made.


In [7]:
import os
from dotenv import load_dotenv
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, END, START
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from sqlalchemy import inspect, text
from langchain_core.runnables.config import RunnableConfig

load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

class AgentState(TypedDict):
    current_user: str
    question: str
    relevance: str
    sql_query: str
    query_rows: str
    sql_error: bool
    query_result: str
    attempts: int

def get_database_schema(engine):
    inspector = inspect(engine)
    schema=""
    for table_name in inspector.get_table_names():
        print(f"Table: {table_name}")
        schema = schema + f"Table: {table_name}\n"
        for column in inspector.get_columns(table_name):
            # print(f"Column: {column['name']} - Type: {column['type']}")
            print(f"{column}")
            col_name = column['name']
            col_type = str(column['type'])
            if column.get('primary_key'):
                col_type = col_type + ", Primary Key"
                # print(col_name,col_type)
            if column.get("foreign_keys"):
                fk = list(column["foreign_keys"])[0]
                col_type = col_type + f", Foreign Key to {fk.column.table.name}.{fk.column.name}"
                # print("insode foreign key")
            schema= schema + f"- {col_name}: {col_type}\n"
        schema = schema + "\n"
        
        print("Retrived Database schema")
    # print(schema)
    
    
def get_current_user(state: AgentState, config: RunnableConfig):
    user_id = config["configurable"].get("currrent_user_id", None)
    if user_id is None:
        state["current_user"] = "User not found"
        print("No user ID is configured.")
        return state
    # print(f"Current user id: {user_id}")
    session = SessionLocal()
    
    try:
        user = session.query(User).filter(User.id == int(user_id)).first()
        # print(f"user: {user.name}")
        
        if user:
            state["current_user"] = user.name
            print(f"Current user set to {user.name}")
        else:
            state["current_user"] = "User not found"
            print("User not found in the database.")
    except Exception as e:
        state["current_user"] = "error while retrieving user"
        print(f"Error while retrieving user: {e}")
    finally:
        session.close()
    return state

class CheckRelevance(BaseModel):
    relevance: str = Field(
        description="Indicates whether the question is related to the database schema. 'relevant' or 'not relevant'"
    )
    
def check_relevance(state: AgentState, config: RunnableConfig):
    question = state["question"]
    # question = "Display me the number of orders present the table"
    schema = get_database_schema(engine)
    print(f"Checking relevance of the question: {question}")
    system = """You are an AI assistant that determines whether a given question is relevant to the following database schema.
    Schema:{schema}
    Respond with only 'relevant' or 'not relevant'.
    """.format(schema=schema)
    human = f"Question: {question}"
    check_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", human)
        ]
    )
    
    llm = ChatOpenAI(temperature=0)
    # llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
    structured_llm = llm.with_structured_output(CheckRelevance)
    relevance_checker = check_prompt | structured_llm
    relevance = relevance_checker.invoke({})
    state['relevance'] = relevance.relevance
    print(f"Relevance determined: {state['relevance']}")
    return state
    
class ConvertToSQL(BaseModel):
  sql_query: str = Field(
      description = "The SQL query corresponding to the user's natural language question."
  )

def convert_nl_to_sql(state: AgentState, config: RunnableConfig):
    question = state['question']
    current_user = state["current_user"]
    schema = get_database_schema(engine)
    print(f"Convert question to SQL for user '{current_user}': {question}")
    system = """You are an assistant that converts an natural language questions into SQL queries based on the following schema:
    {schema}
    the current user is '{current_user}'. Ensure that all query-related data is query-related data is scoped to this user.

    Provide only SQL query without any explanation. ALias columns appropriately to match the expected keys in the result.

    for example, alias 'food.name' as 'food_name' and 'food.price' as 'price'.""".format(schema=schema, current_user=current_user)

    convert_prompt = ChatPromptTemplate(
        [
            ("systme", system),
            ("human", "Question: {question}")
        ]
    )

    llm = ChatOpenAI(temperature=0)
    structured_llm = llm.with_structured_output(ConvertToSQL)
    sql_generator = convert_prompt | structured_llm
    result = sql_generator.invoke({"question": question})
    state['sql_query'] = result.sql_query
    
def execute_sql(state: AgentState):
    sql_query = state["sql_query"].strip()
    session = SessionLocal()
    print(f"Executinf the sql query: {sql_query}")
    try:
        result = session.execute(text(sql_query))
        if sql_query.lower().startswith("select"):
            rows = result.fetchall()
            columns = result.keys()
            if rows:
                header = ", ".join(columns)
                state['query_rows'] = [dict(zip(columns, row)) for row in rows]
                print(f"Raw SQL Query Result: {state['query-rows']}")

                # Format the result for readability
                data = "; ".join([f"{row.get('food_name', row.get('name'))} for ${row.get('price', row.get('food_price'))}" for row in state['query_rows']])
                formatted_result = f"{header}\n{data}"
            else:
                state['query_rows'] = []
                formatted_result = "No result found."
            state["query_result"] = formatted_result
            state["sql_error"] = False
            print("SQL SELECT query executed successfully.")
        else:
            session.commit()
            state["query_result"] = "The action has been successfully completed."
            state["sql_error"] = False
            print("SQL Command executed successfully.")
    except Exception as e:
        state["query_result"] = f"Error Executing SQL Query: {str(e)}"
        state["sql_error"] = True
        print(f"Error executing SQL query: {str(e)}")
    finally:
        session.close()
    return state

def generate_human_readable_answer(state: AgentState):
    sql = state["sql_query"]
    result = state["query_result"]
    current_user = state["current_user"]
    query_rows = state.get("query_rows", [])
    sql_error = state.get("sql_error", False)
    print("Generating a human-readable answer.")
    system = """You are an assistant that converts SQL query results into clear, natural language responses without including any identifiers like orders, users, or food names.
    """
    if sql_error:
      generate_prompt = ChatPromptTemplate.from_messages(
          [
              ("system", system),
              (
                  "human",
                  f"""SQL Query:
                  {sql}
                  
                  Result:
                  {result}
                  
                  Formulate a clear and understandable error message in a single sentence, starting with 'Hello {current_user}, ' informing them about the issue.""" 
              ),
          ]
      )
    elif sql.lower().startswith("select"):
      if not query_rows:
          generate_prompt = ChatPromptTemplate.from_messages(
              [
                  ("system", system),
                  (
                      "huma", f"""SQL Query:
                      {sql}
                      Result;
                      {result}
                      
                      Formulate a clear and understandable answer to the original question in a single sentence, starting with 'Hello {current_user}, ' and inform them that no results were found."""
                  ),
              ]
          )
    else:
      # Handle non select queries
      generate_prompt = ChatPromptTemplate.from_messages(
          [
              ("system", system),
              (
                  "human", f"""SQL Query:
                  {sql}
                  
                  Result:
                  {result}
                  
                  Formulate a clear and understandable confirmation message in a single sentence, starting with 'Hello {current_user}, ' . confirming that your request has been successfully processed."""
              ),
          ]
      )
    llm = ChatOpenAI(temperature=0)
    human_response = generate_prompt | llm | StrOutputParser()
    answer = human_response.invoke({})
    state["query_result"] = answer
    print("Generated human-readable answer.")
    return state
    
class RewrittenQuestion(BaseModel):
  question: str = Field(description="The rewritten question.")

def regenerate_query(state: AgentState):
    question = state["question"]
    print("Regenerating the SQL query by rewriting the question.")
    system = """You are an assistant that reformulates an original question to enable mode precies SQL queries. Ensure that rewritten question is clear and unambiguous."""

    rewrite_prompt = ChatPromptTemplate(
        [
            ("system", system),
            (
                "human",
                f"Original Question: {question}\nRenformulate the question to enable more precise SQL queries, ensuring all relevant information is included. The rewritten question should be clear and unambiguous."
            )
        ]
    )
    llm = ChatOpenAI(temperature=0)
    structured_llm = llm.with_structured_output(RewrittenQuestion)
    rewritter = rewrite_prompt | structured_llm
    rewritten = rewritter.invoke({})
    state["question"] = rewritten.question
    state["attempts"] += 1
    print(f"Rewritten question: {state['question']}")
    return state

def generate_funny_response(state: AgentState):
    print("Generating a funny response")

def end_max_iterations(state: AgentState):
    state["query_result"] = "Please try again."
    print("Maximum attempts reached. Ending the workflow")
    return state

def relevance_router(state: AgentState):
    if state["relevance"].lower() == "relevant":
        return "convert_to_sql"
    else:
        return "generate_funny_response"

def check_attempts_router(state: AgentState):
    if state["attempts"] < 3:
        return "convert_to_sql"
    else:
        return "end_max_iterations"
  
def execute_sql_router(state: AgentState):
    if not state.get("sql_error", False):
        return "generate_human_readable_answer"
    else:
        return "regenerate_query"

In [8]:
workflow = StateGraph(AgentState)

workflow.add_node("get_current_user", get_current_user)
workflow.add_node("check_relevance", check_relevance)
workflow.add_node("convert_to_sql", convert_nl_to_sql)
workflow.add_node("execute_sql", execute_sql)
workflow.add_node("generate_human_readable_answer", generate_human_readable_answer)
workflow.add_node("regenerate_query", regenerate_query)
workflow.add_node("generate_funny_response", generate_funny_response)
workflow.add_node("end_max_iterations", end_max_iterations)

workflow.add_edge(START, "get_current_user")
workflow.add_edge("get_current_user", "check_relevance")
workflow.add_edge("check_relevance", "convert_to_sql")
workflow.add_conditional_edges(
    "check_relevance",
    relevance_router,
    {
        "convert_to_sql": "convert_to_sql",
        "generate_funny_response": "generate_funny_response"
    },
)
workflow.add_edge("convert_to_sql", "execute_sql")

workflow.add_conditional_edges(
    "execute_sql",
    execute_sql_router,
    {
        "generate_human_readable_answer": "generate_human_readable_answer",
        "regenerate_query": "regenerate_query"
    },
)

workflow.add_conditional_edges(
    "regenerate_query",
    check_attempts_router,
    {
        "convert_to_sql": "convert_to_sql",
        "max_iterations": "end_max_iterations",
    },
)

workflow.add_edge("generate_human_readable_answer", END)
workflow.add_edge("generate_funny_response", END)
workflow.add_edge("end_max_iterations", END)
app = workflow.compile()


In [3]:
app

SSLError: HTTPSConnectionPool(host='mermaid.ink', port=443): Max retries exceeded with url: /img/JSV7aW5pdDogeydmbG93Y2hhcnQnOiB7J2N1cnZlJzogJ2xpbmVhcid9fX0lJQpncmFwaCBURDsKCV9fc3RhcnRfXyhbPHA+X19zdGFydF9fPC9wPl0pOjo6Zmlyc3QKCWdldF9jdXJyZW50X3VzZXIoZ2V0X2N1cnJlbnRfdXNlcikKCWNoZWNrX3JlbGV2YW5jZShjaGVja19yZWxldmFuY2UpCgljb252ZXJ0X3RvX3NxbChjb252ZXJ0X3RvX3NxbCkKCWV4ZWN1dGVfc3FsKGV4ZWN1dGVfc3FsKQoJZ2VuZXJhdGVfaHVtYW5fcmVhZGFibGVfYW5zd2VyKGdlbmVyYXRlX2h1bWFuX3JlYWRhYmxlX2Fuc3dlcikKCXJlZ2VuZXJhdGVfcXVlcnkocmVnZW5lcmF0ZV9xdWVyeSkKCWdlbmVyYXRlX2Z1bm55X3Jlc3BvbnNlKGdlbmVyYXRlX2Z1bm55X3Jlc3BvbnNlKQoJZW5kX21heF9pdGVyYXRpb25zKGVuZF9tYXhfaXRlcmF0aW9ucykKCV9fZW5kX18oWzxwPl9fZW5kX188L3A+XSk6OjpsYXN0CglfX3N0YXJ0X18gLS0+IGdldF9jdXJyZW50X3VzZXI7CgljaGVja19yZWxldmFuY2UgLS0+IGNvbnZlcnRfdG9fc3FsOwoJY29udmVydF90b19zcWwgLS0+IGV4ZWN1dGVfc3FsOwoJZW5kX21heF9pdGVyYXRpb25zIC0tPiBfX2VuZF9fOwoJZ2VuZXJhdGVfZnVubnlfcmVzcG9uc2UgLS0+IF9fZW5kX187CglnZW5lcmF0ZV9odW1hbl9yZWFkYWJsZV9hbnN3ZXIgLS0+IF9fZW5kX187CglnZXRfY3VycmVudF91c2VyIC0tPiBjaGVja19yZWxldmFuY2U7CgljaGVja19yZWxldmFuY2UgLS4tPiBjb252ZXJ0X3RvX3NxbDsKCWNoZWNrX3JlbGV2YW5jZSAtLi0+IGdlbmVyYXRlX2Z1bm55X3Jlc3BvbnNlOwoJZXhlY3V0ZV9zcWwgLS4tPiBnZW5lcmF0ZV9odW1hbl9yZWFkYWJsZV9hbnN3ZXI7CglleGVjdXRlX3NxbCAtLi0+IHJlZ2VuZXJhdGVfcXVlcnk7CglyZWdlbmVyYXRlX3F1ZXJ5IC0uLT4gY29udmVydF90b19zcWw7CglyZWdlbmVyYXRlX3F1ZXJ5IC0uICZuYnNwO21heF9pdGVyYXRpb25zJm5ic3A7IC4tPiBlbmRfbWF4X2l0ZXJhdGlvbnM7CgljbGFzc0RlZiBkZWZhdWx0IGZpbGw6I2YyZjBmZixsaW5lLWhlaWdodDoxLjIKCWNsYXNzRGVmIGZpcnN0IGZpbGwtb3BhY2l0eTowCgljbGFzc0RlZiBsYXN0IGZpbGw6I2JmYjZmYwo=?type=png&bgColor=!white (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1020)')))

<langgraph.graph.state.CompiledStateGraph at 0x236a5f71010>

In [10]:
def get_current_user(state: AgentState, config: RunnableConfig):
  user_id = config["configurable"].get("current_user_id", None)
  if user_id is None:
      state["current_user"] = "User not found"
      print("No user ID is configured.")
      return state
  # print(f"Current user id: {user_id}")
  session = SessionLocal()
  
  try:
      user = session.query(User).filter(User.id == int(user_id)).first()
      # print(f"user: {user.name}")
      
      if user:
          state["current_user"] = user.name
          print(f"Current user set to {user.name}")
      else:
          state["current_user"] = "User not found"
          print("User not found in the database.")
  except Exception as e:
      state["current_user"] = "error while retrieving user"
      print(f"Error while retrieving user: {e}")
  finally:
      session.close()
  return state

state = AgentState()
confi = {"configurable": {"current_user_id": "1"}}
get_current_user(state, config=confi)

Current user set to Akshay


{'current_user': 'Akshay'}