In [1]:
from dotenv import load_dotenv
import os 
# Pull variables from a local .env file into the process environment.
load_dotenv()
api_key = os.environ.get("GOOGLE_API_KEY")
# Report only whether the key was found. Printing the value itself would write
# the secret into the notebook's saved cell output.
print("Load API key:", "found" if api_key else "NOT FOUND")



In [2]:
from urllib.parse import quote_plus

import pandas as pd
from sqlalchemy import create_engine
from langchain_google_genai import GoogleGenerativeAI
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableLambda
from langchain_core.output_parsers import StrOutputParser

# Connect to MySQL.
# Credentials are read from the environment (load_dotenv() in the first cell
# fills it from .env) so that no password is stored in the notebook.
#   MYSQL_PASSWORD - required
#   MYSQL_USER     - optional, defaults to "root"
db_user = os.environ.get("MYSQL_USER", "root")
db_password = os.environ.get("MYSQL_PASSWORD")
if not db_password:
    raise RuntimeError("MYSQL_PASSWORD is not set; add it to your .env file.")

# quote_plus escapes characters such as '@', ':' or '/' that would otherwise
# be parsed as URL delimiters.
db_uri = (
    f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}"
    "@localhost:3306/new_journey_2025"
)
# LangChain wrapper restricted to the `shop` table; supplies schema text and dialect.
db = SQLDatabase.from_uri(database_uri=db_uri, include_tables=["shop"])
# Separate SQLAlchemy engine that pandas uses to run the generated queries.
engine = create_engine(db_uri)

# Create LLM.
# temperature=0 keeps SQL generation as repeatable as the model allows; the
# generated text is executed against the database, so sampling randomness
# (the previous value was 1.7) only adds a risk of malformed or wrong queries.
llm = GoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)

# Prompt that asks the model for one raw SQL statement answering the question.
# Template variables:
#   {dialect}    - SQL dialect name reported by the database wrapper
#   {table_info} - schema description of the tables the model may use
#   {input}      - the user's natural-language question
# {dialect} is now referenced in the text; previously it was declared as an
# input variable but the template hardcoded the dialect name instead.
sql_prompt = PromptTemplate(
    input_variables=["input", "table_info", "dialect"],
    template="""
You are an expert in converting natural language questions to SQL queries for {dialect}.
Rules:
- Do not use markdown formatting (no ```sql ... ```).
- Only return the raw SQL query.
- Use only the provided tables: {table_info}

Question: {input}
SQL Query:
"""
)

def clean_sql_query(query: str) -> str:
    """Strip surrounding whitespace and a Markdown code fence from model output.

    Handles an opening fence with no language tag (```), or with an ``sql`` /
    ``mysql`` tag in any letter case, plus an optional closing fence. Text
    that carries no fence is returned with only outer whitespace removed.

    Args:
        query: Raw text produced by the language model.

    Returns:
        The SQL text without fence markers or leading/trailing whitespace.
    """
    text = query.strip()

    if text.startswith("```"):
        text = text[3:]
        # Drop the language tag when present. "mysql" is tested before "sql"
        # only for readability; neither is a prefix of the other.
        lowered = text.lower()
        for tag in ("mysql", "sql"):
            if lowered.startswith(tag):
                text = text[len(tag):]
                break
        text = text.strip()

    if text.endswith("```"):
        text = text[:-3].strip()

    return text

def run_sql_to_dataframe(query: str) -> pd.DataFrame:
    """Clean model-generated SQL, print it, and run it through pandas.

    Args:
        query: Raw SQL text as returned by the language model.

    Returns:
        The query result as a DataFrame, read via the module-level ``engine``.

    Raises:
        ValueError: If the cleaned statement does not begin with a read-only
            keyword (SELECT, WITH, SHOW, DESCRIBE, DESC, EXPLAIN).
    """
    cleaned = clean_sql_query(query)
    print("SQL code:", cleaned)

    # The statement is model output, so treat it as untrusted: refuse anything
    # that is not obviously a read before it reaches the database. This is a
    # best-effort check on the leading keyword, not a SQL parser.
    body = cleaned.lstrip("( \t\r\n")
    keyword_end = next(
        (pos for pos, ch in enumerate(body) if not ch.isalpha()),
        len(body),
    )
    first_keyword = body[:keyword_end].lower()
    if first_keyword not in {"select", "with", "show", "describe", "desc", "explain"}:
        raise ValueError(f"Refusing to execute non read-only SQL: {cleaned!r}")

    return pd.read_sql_query(cleaned, engine)

# Fan the incoming question out into the three variables named in the prompt's
# input_variables: the question itself is passed through unchanged, while the
# schema text and dialect are looked up on `db` each time the chain runs.
sql_chain_input = RunnableParallel(
    input=RunnablePassthrough(),
    table_info=lambda _question: db.get_table_info(),
    dialect=lambda _question: db.dialect,
)

# Full pipeline: question -> prompt variables -> prompt -> LLM -> plain string
# -> executed SQL. The final output is a DataFrame.
chain = sql_chain_input.pipe(
    sql_prompt,
    llm,
    StrOutputParser(),
    RunnableLambda(run_sql_to_dataframe),   # final output → DataFrame
)

# Natural-language questions to run through the chain.
queries = [
    "For each dealer, find how many articles they sell below the average price of that article.",
    "Find the top 2 cheapest dealers per article."
]

# Ask each question in turn and print the resulting rows without the index column.
for q in queries:
    print("\nUser:", q)
    df = chain.invoke(q)      # the chain's last step returns a DataFrame
    rendered_rows = df.to_string(index=False)
    print(rendered_rows)


User: For each dealer, find how many articles they sell below the average price of that article.
SQL code: SELECT dealer, count(*) FROM shop AS s1 WHERE price < (SELECT avg(price) FROM shop AS s2 WHERE s1.article = s2.article) GROUP BY dealer
dealer  count(*)
     A         1
     B         1
     D         1

User: Find the top 2 cheapest dealers per article.
SQL code: SELECT article, dealer, price FROM (SELECT article, dealer, price, ROW_NUMBER() OVER (PARTITION BY article ORDER BY price ASC) as rn FROM shop) as subquery WHERE rn <= 2
 article dealer  price
       1      A   3.45
       1      B   3.99
       2      A  10.99
       3      D   1.25
       3      B   1.45
       4      D  19.95
