In [296]:
from langchain_google_genai import ChatGoogleGenerativeAI
from dotenv import load_dotenv
import os

In [297]:
load_dotenv() 
api_key = os.getenv("GOOGLE_API_KEY")

In [298]:
llm  = ChatGoogleGenerativeAI(model = 'gemini-1.5-flash', api_key=api_key, temperature=0)

In [299]:
from langchain.utilities import SQLDatabase 
db_user = "root"
db_password ="root"
db_host = "localhost"
db_name = "atliq_tshirts"

db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")


In [293]:
db.run("SELECT * FROM t_shirts")

"[(1, 'Levi', 'Blue', 'L', 48, 63), (2, 'Van Huesen', 'Red', 'L', 39, 33), (3, 'Van Huesen', 'White', 'M', 35, 62), (4, 'Adidas', 'Red', 'S', 29, 65), (5, 'Nike', 'Blue', 'XS', 30, 14), (6, 'Nike', 'Blue', 'S', 42, 33), (7, 'Adidas', 'Black', 'S', 12, 26), (8, 'Nike', 'Red', 'S', 50, 32), (9, 'Van Huesen', 'Blue', 'S', 34, 92), (10, 'Nike', 'White', 'S', 28, 70), (12, 'Nike', 'Blue', 'L', 42, 89), (13, 'Adidas', 'Blue', 'M', 29, 97), (14, 'Levi', 'Red', 'M', 11, 77), (15, 'Nike', 'Black', 'S', 46, 49), (16, 'Nike', 'Red', 'M', 15, 12), (18, 'Adidas', 'Blue', 'L', 22, 34), (19, 'Levi', 'Red', 'XL', 20, 34), (20, 'Nike', 'Red', 'XS', 49, 53), (22, 'Levi', 'Blue', 'M', 42, 72), (23, 'Van Huesen', 'Red', 'XL', 30, 16), (24, 'Adidas', 'Red', 'M', 28, 77), (25, 'Levi', 'Black', 'S', 29, 53), (27, 'Levi', 'Black', 'XS', 11, 21), (28, 'Nike', 'Blue', 'XL', 22, 16), (29, 'Levi', 'White', 'S', 32, 96), (30, 'Van Huesen', 'Black', 'S', 47, 97), (32, 'Levi', 'Blue', 'S', 25, 89), (33, 'Van Huesen'

In [300]:
def get_schema(_):
    schema = db.get_table_info()
    return schema

In [301]:
from langchain_core.prompts import ChatPromptTemplate

template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)


In [302]:
def clean_sql_string(sql_string):
    cleaned_string = sql_string.strip("```").lstrip("sql").strip()
    return cleaned_string

In [303]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


In [304]:
q =sql_chain.invoke({"question": "How many brands are there"})

In [305]:
print(q)

```sql
SELECT COUNT(DISTINCT brand) AS brand_count
FROM t_shirts;
```


In [306]:
q=clean_sql_string(q)
print(q)

SELECT COUNT(DISTINCT brand) AS brand_count
FROM t_shirts;


In [307]:
db.run(q)

'[(4,)]'

In [241]:
db.run("SELECT stock_quantity FROM t_shirts WHERE brand = 'Nike' AND size = 'XS' AND color = 'White';")

'[(44,)]'

In [308]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

In [309]:
def run_query(query):
    return db.run(query)

In [312]:
from operator import itemgetter
full_chain = (
    RunnablePassthrough.assign(query=sql_chain)
    .assign(
        schema=get_schema,
        response=lambda vars: run_query(clean_sql_string(vars["query"])), 
    )
    | prompt_response
    | llm
    | StrOutputParser()
)


In [311]:
full_chain.invoke({"question": "How many brands are there in the database"})

'There are 4 different brands of t-shirts in the database. \n'

In [313]:
full_chain.invoke({"question":"How many t-shirts do we have left in stock for nike in extra small and white color?"})

'We have 44 Nike t-shirts in stock that are extra small and white. \n'