In [1]:
from typing_extensions import TypedDict

class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str

In [2]:
from langchain.utilities import SQLDatabase

db_user = "root"
db_password = "root123"
db_host = "localhost"
db_name = "atliq_tshirts"

try:
    # Attempt to create a database connection
    db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}", sample_rows_in_table_info=3)
    print("DATABASE CONNECTED SUCCESSFULLY")
except Exception as e:
    # Print failure message and the specific error
    print("DATABASE CONNECTION FAILED")
    print(f"Error: {str(e)}")
    exit(1)

DATABASE CONNECTED SUCCESSFULLY


In [3]:
print(db.table_info)


CREATE TABLE discounts (
	discount_id INTEGER NOT NULL AUTO_INCREMENT, 
	t_shirt_id INTEGER NOT NULL, 
	pct_discount DECIMAL(5, 2), 
	PRIMARY KEY (discount_id), 
	CONSTRAINT discounts_ibfk_1 FOREIGN KEY(t_shirt_id) REFERENCES t_shirts (t_shirt_id), 
	CONSTRAINT discounts_chk_1 CHECK ((`pct_discount` between 0 and 100))
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

/*
3 rows from discounts table:
discount_id	t_shirt_id	pct_discount
1	1	10.00
2	2	15.00
3	3	20.00
*/


CREATE TABLE t_shirts (
	t_shirt_id INTEGER NOT NULL AUTO_INCREMENT, 
	brand ENUM('Van Huesen','Levi','Nike','Adidas') NOT NULL, 
	color ENUM('Red','Blue','Black','White') NOT NULL, 
	size ENUM('XS','S','M','L','XL') NOT NULL, 
	price INTEGER, 
	stock_quantity INTEGER NOT NULL, 
	PRIMARY KEY (t_shirt_id), 
	CONSTRAINT t_shirts_chk_1 CHECK ((`price` between 10 and 50))
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

/*
3 rows from t_shirts table:
t_shirt_id	brand	color	size	price	stock

In [4]:
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM t_shirts LIMIT 10;"))

mysql
['discounts', 't_shirts']
[(1, 'Levi', 'Blue', 'M', 21, 82), (2, 'Van Huesen', 'Blue', 'XL', 15, 71), (3, 'Adidas', 'White', 'M', 46, 88), (4, 'Nike', 'White', 'S', 49, 84), (5, 'Van Huesen', 'White', 'XS', 47, 53), (6, 'Nike', 'Black', 'L', 14, 65), (7, 'Nike', 'White', 'XS', 49, 48), (8, 'Levi', 'Red', 'L', 26, 78), (9, 'Nike', 'Black', 'XS', 30, 45), (10, 'Levi', 'Red', 'M', 41, 51)]


### **SETTING UP THE API  KEY**

In [5]:
import getpass
import os

if not os.environ.get("OPENAI_API_KEY"):
  os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")
query_prompt_template.messages[0].pretty_print()

In [6]:
from langchain import hub

query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

assert len(query_prompt_template.messages) == 1



In [7]:
from typing_extensions import Annotated


class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: State):
    """Generate SQL query to fetch information."""
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}

In [8]:
query = "How many t-shirts do we have left for Nike in extra small and white color?"
sql = write_query({"question": query})

In [15]:
print(sql.get("query"))

SELECT stock_quantity FROM t_shirts WHERE brand = 'Nike' AND size = 'XS' AND color = 'White' LIMIT 10;


In [15]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool


def execute_query(state: State):
    """Execute SQL query."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}

In [19]:
answer = execute_query({"query": sql})

In [18]:
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f'Question: {state["question"]}\n'
        f'SQL Query: {state["query"]}\n'
        f'SQL Result: {state["result"]}'
    )
    response = llm.invoke(prompt)
    return {"answer": response.content}

In [32]:
result_str = answer['result']

# Convert the string to an actual Python object
import ast  # Safely evaluates the string as a Python literal
parsed_result = ast.literal_eval(result_str)

# Access the first value (48)
value = parsed_result[0][0]

# Print or return the value
print(value)

48


In [28]:
generate_answer({"question": query, "query": sql, "result": "10"})

{'answer': 'We have 10 t-shirts left for Nike in extra small and white color.'}