In [1]:
# %pip install langchain langchain_community langchain_openai langchain_google_genai langchain_groq google-generativeai pymysql chromadb

In [2]:
import os
import warnings
from sqlalchemy.exc import SAWarning
from langchain_groq import ChatGroq
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.prompts.prompt import PromptTemplate

In [3]:
# Silence SQLAlchemy's SAWarning category for the rest of the session so the
# cell outputs below are not cluttered with them.
warnings.simplefilter("ignore", category=SAWarning)

In [4]:
from urllib.parse import quote_plus

# Connection settings come from the environment so that real credentials never
# have to be saved inside the notebook. The fallbacks are the local development
# values this notebook was written against.
db_user = os.getenv("DB_USER", "root")
db_password = os.getenv("DB_PASSWORD", "admin")
db_host = os.getenv("DB_HOST", "localhost:3306")
db_name = os.getenv("DB_NAME", "sakila")           # "airline_passenger_details"

# Percent-encode the user name and password: characters such as "@", ":" or "/"
# inside a credential would otherwise be parsed as URL delimiters.
db_uri = f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}@{db_host}/{db_name}"
db = SQLDatabase.from_uri(db_uri)
# print(db.dialect)
print(db.get_usable_table_names())
# print(db.table_info)


['actor', 'address', 'category', 'city', 'country', 'customer', 'film', 'film_actor', 'film_category', 'film_text', 'inventory', 'language', 'payment', 'rental', 'staff', 'store']


In [None]:
import os
import warnings

# LangSmith tracing configuration.
# The API key is a secret and must NOT be stored in the notebook: it is expected
# to be present in the environment already (shell export, .env loader, secrets
# manager, ...). Any key that was ever committed in this cell should be treated
# as compromised and rotated.
os.environ['LANGSMITH_ENDPOINT'] = "https://api.smith.langchain.com"
os.environ['LANGSMITH_PROJECT'] = "pr-linear-pathway-70"

if os.environ.get('LANGSMITH_API_KEY'):
    os.environ['LANGSMITH_TRACING'] = "true"
else:
    # Without a key tracing cannot authenticate, so switch it off instead of
    # letting every later LLM call attempt (and fail) to upload traces.
    os.environ['LANGSMITH_TRACING'] = "false"
    warnings.warn(
        "LANGSMITH_API_KEY is not set; LangSmith tracing has been disabled.",
        stacklevel=2,
    )

In [7]:
# Read the Gemini key from the environment and fail early with an actionable
# message: assigning None into os.environ would otherwise raise an opaque
# "str expected, not NoneType" TypeError.
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    raise RuntimeError(
        "GOOGLE_API_KEY is not set. Export it before starting the kernel "
        "(or load it from a .env file) and re-run this cell."
    )
os.environ["GOOGLE_API_KEY"] = google_api_key

In [8]:
import google.generativeai as genai

# Authenticate the google.generativeai client with the key loaded above.
genai.configure(api_key=google_api_key)

# Print the identifier of every model returned for this key, one per line.
for available_model in genai.list_models():
    print(available_model.name)


  from .autonotebook import tqdm as notebook_tqdm


models/chat-bison-001
models/text-bison-001
models/embedding-gecko-001
models/gemini-1.0-pro-vision-latest
models/gemini-pro-vision
models/gemini-1.5-pro-latest
models/gemini-1.5-pro-001
models/gemini-1.5-pro-002
models/gemini-1.5-pro
models/gemini-1.5-flash-latest
models/gemini-1.5-flash-001
models/gemini-1.5-flash-001-tuning
models/gemini-1.5-flash
models/gemini-1.5-flash-002
models/gemini-1.5-flash-8b
models/gemini-1.5-flash-8b-001
models/gemini-1.5-flash-8b-latest
models/gemini-1.5-flash-8b-exp-0827
models/gemini-1.5-flash-8b-exp-0924
models/gemini-2.5-pro-exp-03-25
models/gemini-2.5-pro-preview-03-25
models/gemini-2.5-flash-preview-04-17
models/gemini-2.0-flash-exp
models/gemini-2.0-flash
models/gemini-2.0-flash-001
models/gemini-2.0-flash-exp-image-generation
models/gemini-2.0-flash-lite-001
models/gemini-2.0-flash-lite
models/gemini-2.0-flash-lite-preview-02-05
models/gemini-2.0-flash-lite-preview
models/gemini-2.0-pro-exp
models/gemini-2.0-pro-exp-02-05
models/gemini-exp-1206
m

In [9]:
from langchain_google_genai import ChatGoogleGenerativeAI
# Chat model used for both SQL generation and answer phrasing; temperature 0
# is requested so the sampling is as repeatable as the backend allows.
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",  # alternative: gemini-1.5-pro-latest
    temperature=0,
)

In [10]:
# llm=ChatGroq(groq_api_key=groq_api_key,model_name="mixtral-8x7b-32768") # Gemma2-9b-It
# llm

In [11]:
# from langchain.chains import create_sql_query_chain
# from langchain_openai import ChatOpenAI

# llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# generate_query = create_sql_query_chain(llm, db)
# query = generate_query.invoke({"question": "`Sanjay V` where is he going to?"})
# # "what is price of `1968 Ford Mustang`"
# print(query)

# OUTPUT ##
# SELECT `Passenger name`, `Arrival place`
# FROM passengers_details
# WHERE `Passenger name` = 'Sanjay V'
# LIMIT 1;

In [12]:
from langchain.chains.sql_database.prompt import PROMPT

# Work on a copy of the stock SQL prompt so PROMPT itself is left untouched.
custom_prompt = PROMPT.model_copy()

# Strip the literal "SQLResult: " label from the template, trim the ends, and
# append an instruction asking the model for the bare SQL only.
trimmed_template = custom_prompt.template.replace("SQLResult: ", "").strip()
custom_prompt.template = (
    trimmed_template
    + "\nReturn only the SQL query without any additional explanation or formatting."
)
custom_prompt.pretty_print()

Given an input question, first create a syntactically correct [33;1m[1;3m{dialect}[0m query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most [33;1m[1;3m{top_k}[0m results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
[33;1m[1;3m{table_info}[0m

Question: [33;1m[1;3m{input}[0m
Return only the SQL query without 

In [13]:
def extract_sql_query(text):
    """Return the bare SQL statement contained in an LLM response.

    If the response contains a Markdown fence tagged ``sql`` (matched
    case-insensitively), only the content of the first such fence is returned,
    so any explanation before or after the fenced block is dropped. If the
    fence is never closed, everything after the opening tag is used. Without an
    ``sql`` fence, stray triple backticks are removed and the text is returned
    as is. Surrounding whitespace is always stripped.

    Args:
        text: Raw model output.

    Returns:
        The cleaned SQL string.

    Raises:
        TypeError: If ``text`` is not a string. (Raised rather than returned, so
            an exception object can never be passed on as if it were a query.)
    """
    import re

    if not isinstance(text, str):
        raise TypeError(
            f"extract_sql_query expects a string, got {type(text).__name__}"
        )

    # Lazy match up to the first closing fence; \Z covers an unterminated fence.
    fenced = re.search(
        r"```[ \t]*sql(.*?)(?:```|\Z)", text, flags=re.IGNORECASE | re.DOTALL
    )
    if fenced:
        return fenced.group(1).strip()

    # No sql-tagged fence: just drop any stray backtick fences.
    return text.replace("```", "").strip()

In [14]:
from langchain.chains import create_sql_query_chain

# Build the question -> SQL chain on top of the customised prompt.
generate_query = create_sql_query_chain(llm, db, prompt=custom_prompt)

# Ask one question, then strip any Markdown fencing from the model's reply.
raw_response = generate_query.invoke({"question": "Get the address of 'ANNA HILL'"})
query = extract_sql_query(raw_response)
query

"SELECT address, district\nFROM address a\nJOIN customer c ON a.address_id = c.address_id\nWHERE c.first_name = 'ANNA' AND c.last_name = 'HILL';"

In [15]:
from langchain_community.tools import QuerySQLDatabaseTool
# Tool that runs a SQL string against `db`; it is reused by the chains below.
execute_query = QuerySQLDatabaseTool(db=db)
# Run the query generated in the previous cell and display the tool's output.
execute_query.invoke(query)

"[('127 Purnea (Purnia) Manor', 'Piemonte')]"

In [16]:
# Input payload for the query chain (a second, harder question is kept for reference).
question = {"question": "list all the film acted by 'JOE SWANK'"}
# question = {"question": "Which actor has generated the highest total revenue across all films they acted in, and what is the average rating of those films?"}

# 1) question -> SQL text, with Markdown fencing removed
sql_query = generate_query.invoke(question)
q_1 = extract_sql_query(sql_query)
print(f"Generated SQL Query: {q_1}")

# 2) SQL text -> database result
final_result = execute_query.invoke(q_1)
print(f"Final Result: {final_result}")

Generated SQL Query: SELECT f.title
FROM film AS f
JOIN film_actor AS fa ON f.film_id = fa.film_id
JOIN actor AS a ON fa.actor_id = a.actor_id
WHERE a.first_name = 'JOE' AND a.last_name = 'SWANK'
LIMIT 5;
Final Result: [('ANYTHING SAVANNAH',), ('BIRCH ANTITRUST',), ('CHOCOLAT HARRY',), ('CHOCOLATE DUCK',), ('CROOKED FROGMEN',)]


In [18]:
# chain = generate_query | execute_query
# chain.invoke({"question": "list all the film acted by 'JOE SWANK'"})
# chain.invoke({"question": "which actor acted more than 5 films and all must be Englist?'"})

In [21]:
from operator import itemgetter
from pprint import pprint
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

# Truncation function (approx. 1 token ≈ 4 characters)
def truncate_text_approx(text: str, max_chars: int = 12000):
    """Cap ``text`` at ``max_chars`` characters.

    Text that already fits is returned unchanged; longer text is cut to its
    first ``max_chars`` characters.
    """
    if len(text) <= max_chars:
        return text
    return text[:max_chars]

# Truncation wrapper for LangChain
# Runnable step: copy the incoming dict and replace its "question" entry with a
# version capped at 12000 characters; every other key passes through untouched.
truncate_question = RunnableLambda(
    lambda payload: {
        **payload,
        "question": truncate_text_approx(payload["question"], 12000),
    }
)

# SQL cleanup
def clean_sql_query(query: str) -> str:
    """Strip Markdown fencing from model output and return the bare SQL.

    If the text contains a fence tagged ``sql`` (matched case-insensitively),
    only the content of the first such fence is returned, so commentary placed
    before or after the fenced block cannot leak into the statement that gets
    executed. An unterminated fence yields everything after the opening tag.
    Without an ``sql`` fence, stray triple backticks are removed. The result is
    always stripped of surrounding whitespace.
    """
    import re

    # Lazy match up to the first closing fence; \Z covers an unterminated fence.
    fenced = re.search(
        r"```[ \t]*sql(.*?)(?:```|\Z)", query, flags=re.IGNORECASE | re.DOTALL
    )
    if fenced:
        return fenced.group(1).strip()
    return query.replace("```", "").strip()

# Runnable wrapper so the SQL cleanup can be piped after the query generator.
clean_query = RunnableLambda(lambda raw_sql: clean_sql_query(raw_sql))

# Prompt that turns (question, SQL, SQL result) into a natural-language answer.
answer_prompt = PromptTemplate.from_template(
    "Given the following user question, corresponding SQL query, and SQL result, answer the question.\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Result: {result}\n"
    "Answer:"
)

# prompt -> model -> plain string
rephrase_answer = answer_prompt | llm | StrOutputParser()

# Full pipeline:
#   1. cap the question length
#   2. add "query"  (generated SQL, fencing removed)
#   3. add "result" (output of running that SQL)
#   4. phrase the final answer
add_query = RunnablePassthrough.assign(query=generate_query | clean_query)
add_query_and_result = add_query.assign(result=itemgetter("query") | execute_query)
chain = truncate_question | add_query_and_result | rephrase_answer

# Example usage
question = "Which actor acted in more than 5 films and all must be English?"
output = chain.invoke({"question": question})
print(output)


Penelope Guiness, Nick Wahlberg, Ed Chase, Jennifer Davis, and Johnny Lollobrigida each acted in more than 5 English films.


### Few-shot examples can significantly improve the model's SQL generation
### Example format

In [22]:
# Few-shot examples: each entry pairs a natural-language question ("input")
# with the SQL that answers it ("query"). The same list feeds both the static
# few-shot prompt and the similarity-based example selector further down.
examples = [
    {
        "input":"list all the film acted by 'JOE SWANK'",
        "query":"SELECT f.title FROM film AS f JOIN film_actor AS fa ON f.film_id = fa.film_id JOIN actor AS a ON fa.actor_id = a.actor_id WHERE a.first_name = 'JOE' AND a.last_name = 'SWANK' LIMIT 5;"
    },
    # NOTE(review): a later question in this notebook treats film ratings as
    # labels such as 'PG-13'; if so, AVG(f.rating) is probably not a meaningful
    # "average rating" — confirm against the schema.
    {
        "input": "Which actor has generated the highest total revenue across all films they acted in, and what is the average rating of those films?",
        "query": "SELECT a.first_name, a.last_name, SUM(p.amount) AS total_revenue, AVG(f.rating) AS average_rating FROM actor a JOIN film_actor fa ON a.actor_id = fa.actor_id JOIN film f ON fa.film_id = f.film_id JOIN inventory i ON f.film_id = i.film_id JOIN rental r ON i.inventory_id = r.inventory_id JOIN payment p ON r.rental_id = p.rental_id GROUP BY a.actor_id ORDER BY total_revenue DESC LIMIT 1;"
    },
    {
        "input": "Find all customers who have rented films from both the 'Horror' and 'Comedy' categories but never rented a 'Documentary' film.",
        "query": "SELECT c.first_name, c.last_name FROM customer AS c JOIN rental AS r ON c.customer_id = r.customer_id JOIN inventory AS i ON r.inventory_id = i.inventory_id JOIN film AS f ON i.film_id = f.film_id JOIN film_category AS fc ON f.film_id = fc.film_id JOIN category AS cat ON fc.category_id = cat.category_id WHERE cat.name IN ('Horror', 'Comedy') GROUP BY c.customer_id HAVING COUNT(DISTINCT CASE WHEN cat.name = 'Horror' THEN cat.category_id END) > 0 AND COUNT(DISTINCT CASE WHEN cat.name = 'Comedy' THEN cat.category_id END) > 0 AND c.customer_id NOT IN (SELECT c2.customer_id FROM customer AS c2 JOIN rental AS r2 ON c2.customer_id = r2.customer_id JOIN inventory AS i2 ON r2.inventory_id = i2.inventory_id JOIN film AS f2 ON i2.film_id = f2.film_id JOIN film_category AS fc2 ON f2.film_id = fc2.film_id JOIN category AS cat2 ON fc2.category_id = cat2.category_id WHERE cat2.name = 'Documentary') LIMIT 5;"
    },
    # NOTE(review): "Englist" looks like a typo for "English". It is left as is
    # because this text is what gets embedded for similarity search.
    {
        "input": "which actor acted more than 5 films and all must be Englist?",
        "query": "SELECT a.first_name, a.last_name FROM actor a JOIN film_actor fa ON a.actor_id = fa.actor_id JOIN film f ON fa.film_id = f.film_id JOIN language l ON f.language_id = l.language_id WHERE l.name = 'English' GROUP BY a.actor_id HAVING COUNT(f.film_id) > 5 LIMIT 5;"
    },
    # NOTE(review): num_rentals is computed as COUNT(p.payment_id), i.e. it
    # counts payments — presumably one per rental; verify.
    {
        "input": "List the top 5 customers who have spent the most money, along with their total amount spent, number of rentals, and their country.",
        "query": "SELECT c.first_name, c.last_name, SUM(p.amount) AS total_amount_spent, COUNT(p.payment_id) AS num_rentals, co.country FROM customer AS c JOIN address AS a ON c.address_id = a.address_id JOIN city AS ci ON a.city_id = ci.city_id JOIN country AS co ON ci.country_id = co.country_id JOIN payment AS p ON c.customer_id = p.customer_id GROUP BY c.customer_id ORDER BY total_amount_spent DESC LIMIT 5;"
    }
]


In [23]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

# How a single example is rendered: the question as a human turn (ending with
# the "SQLQuery:" cue) followed by the SQL as the AI turn.
example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)

# Static few-shot block: every entry of `examples` is always included.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    # input_variables=["input","top_k"],
    input_variables=["input"],
)

# Pass the declared variable name ("input"). The previous call used a
# misspelled keyword ("input1"), which went unnoticed because this template
# renders its fixed examples whatever arguments it is given.
print(few_shot_prompt.format(input="How many films are there?"))


Human: list all the film acted by 'JOE SWANK'
SQLQuery:
AI: SELECT f.title FROM film AS f JOIN film_actor AS fa ON f.film_id = fa.film_id JOIN actor AS a ON fa.actor_id = a.actor_id WHERE a.first_name = 'JOE' AND a.last_name = 'SWANK' LIMIT 5;
Human: Which actor has generated the highest total revenue across all films they acted in, and what is the average rating of those films?
SQLQuery:
AI: SELECT a.first_name, a.last_name, SUM(p.amount) AS total_revenue, AVG(f.rating) AS average_rating FROM actor a JOIN film_actor fa ON a.actor_id = fa.actor_id JOIN film f ON fa.film_id = f.film_id JOIN inventory i ON f.film_id = i.film_id JOIN rental r ON i.inventory_id = r.inventory_id JOIN payment p ON r.rental_id = p.rental_id GROUP BY a.actor_id ORDER BY total_revenue DESC LIMIT 1;
Human: Find all customers who have rented films from both the 'Horror' and 'Comedy' categories but never rented a 'Documentary' film.
SQLQuery:
AI: SELECT c.first_name, c.last_name FROM customer AS c JOIN rental AS r

### Implementing Dynamic Few-Shot Selection

In [24]:
from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Embedding model handed to the example selector below.
embeddings = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001",
)


In [25]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
# from langchain_openai import OpenAIEmbeddings

# Open the default Chroma collection and delete it before indexing.
# Presumably this keeps re-runs of the cell from accumulating duplicate copies
# of the examples — TODO confirm against the Chroma client in use.
vectorstore = Chroma()
vectorstore.delete_collection()

# `from_examples` takes the vector store *class* and builds a fresh store from
# the examples itself, so pass `Chroma` rather than the (now emptied) instance.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    Chroma,
    k=2,                   # number of nearest examples to return
    input_keys=["input"],  # only the question text is embedded
)

# Smoke test of the selector; print the hits instead of discarding them.
nearest_examples = example_selector.select_examples({"input": "how many employees we have?"})
print([hit["input"] for hit in nearest_examples])

# Dynamic few-shot block: examples are chosen per question by similarity.
# "top_k" is declared although the example template does not use it;
# presumably so a prompt embedding this block exposes the variable that
# create_sql_query_chain supplies — confirm before removing.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["input","top_k"],
)
print(few_shot_prompt.format(input="How many products are there?"))

  vectorstore = Chroma()


Human: which actor acted more than 5 films and all must be Englist?
SQLQuery:
AI: SELECT a.first_name, a.last_name FROM actor a JOIN film_actor fa ON a.actor_id = fa.actor_id JOIN film f ON fa.film_id = f.film_id JOIN language l ON f.language_id = l.language_id WHERE l.name = 'English' GROUP BY a.actor_id HAVING COUNT(f.film_id) > 5 LIMIT 5;
Human: List the top 5 customers who have spent the most money, along with their total amount spent, number of rentals, and their country.
SQLQuery:
AI: SELECT c.first_name, c.last_name, SUM(p.amount) AS total_amount_spent, COUNT(p.payment_id) AS num_rentals, co.country FROM customer AS c JOIN address AS a ON c.address_id = a.address_id JOIN city AS ci ON a.city_id = ci.city_id JOIN country AS co ON ci.country_id = co.country_id JOIN payment AS p ON c.customer_id = p.customer_id GROUP BY c.customer_id ORDER BY total_amount_spent DESC LIMIT 5;


In [26]:
# Inspect the input schema class that the current `chain` reports.
print(chain.input_schema)

<class 'langchain_core.runnables.base.RunnableLambdaInput'>


In [27]:
# Full chat prompt: system instructions + the few-shot block + the user's
# question. The system message previously contained the dangling fragment
# "Unless otherwise specificed." — the row-limit instruction it introduces is
# restored here, which also puts the `top_k` value to use.
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specified, do not return more than {top_k} rows.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        few_shot_prompt,
        ("human", "{input}"),
    ]
)
# sql preprocessing function
def clean_sql_query(query: str) -> str:
    """Strip Markdown fencing from model output and return the bare SQL.

    If the text contains a fence tagged ``sql`` (matched case-insensitively),
    only the content of the first such fence is returned, so commentary placed
    before or after the fenced block cannot leak into the statement that gets
    executed. An unterminated fence yields everything after the opening tag.
    Without an ``sql`` fence, stray triple backticks are removed. The result is
    always stripped of surrounding whitespace.
    """
    import re

    # Lazy match up to the first closing fence; \Z covers an unterminated fence.
    fenced = re.search(
        r"```[ \t]*sql(.*?)(?:```|\Z)", query, flags=re.IGNORECASE | re.DOTALL
    )
    if fenced:
        return fenced.group(1).strip()
    return query.replace("```", "").strip()

# Runnable wrapper around the cleanup helper defined just above.
clean_query = RunnableLambda(lambda raw_sql: clean_sql_query(raw_sql))

# print('TEST_1 >> ', final_prompt.format(input="How many products are there?",table_info="some table info"))

# Rebuild the query generator on the few-shot prompt, then reassemble the
# question -> SQL -> result -> answer pipeline around it.
generate_query = create_sql_query_chain(llm, db, prompt=final_prompt)
with_query = RunnablePassthrough.assign(query=generate_query | clean_query)
chain = with_query.assign(result=itemgetter("query") | execute_query) | rephrase_answer

chain.invoke({"question": "how many englist movie rated PG-13?"})


'There are 223 English movies rated PG-13.'