In [79]:
from dotenv import load_dotenv
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableSequence,RunnableBranch,RunnablePassthrough,RunnableLambda
from langchain_groq import ChatGroq
from pydantic import BaseModel, Field

In [80]:
# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this is where the Groq API key comes from - confirm.
load_dotenv()

True

In [81]:
# Single chat model instance shared by both the generation and refinement chains.
model = ChatGroq(
    model="llama-3.1-8b-instant",
    temperature=0,
)

In [82]:
class SQLSchema(BaseModel):
    # Structured answer expected from the LLM on every pass.
    #
    # Documented with comments rather than a class docstring on purpose:
    # pydantic copies a model docstring into the JSON schema, and that schema
    # is what the output parser sends to the LLM as format instructions.
    #
    # The field descriptions below DO go into that schema; they tell the model
    # what each field means and, in particular, which scale the confidence
    # score uses (the notebook compares it against 0.6).
    sql: str = Field(
        description="A single SQL SELECT statement that answers the request"
    )
    confidence_score: float = Field(
        description=(
            "Confidence that the SQL is correct, as a number from 0.0 "
            "(no confidence) to 1.0 (certain)"
        )
    )
    feedback: str = Field(
        description=(
            "Short critique of the SQL: assumptions made and what could be improved"
        )
    )

In [83]:
# Parser that validates the model's reply into a SQLSchema instance; it also
# supplies the format instructions injected into both prompts.
parser=PydanticOutputParser(pydantic_object=SQLSchema)

In [84]:
# System rules for the first (generation) pass.
# NOTE(review): the text mentions a schema, but the template below has no
# schema variable - any schema has to be part of `natural_language`.
sys_instruction = """
You are a SQL Query Expert. Given a natural language, and schema you will write a optimized sql code to query information from the schema.
---------RULES-----------------
1. You must return only JSON object
2. You must only return SELECT commands
3. Return sql code,confidence score,feedback
4. Do not include any introductory text, markdown explanations, or extra commentary.
"""

# First-pass prompt: system rules followed by the parser's format instructions,
# then the user's request. `{format}` is pre-filled through partial(), so the
# only variable left for callers to supply is `natural_language`.
prompt1 = ChatPromptTemplate.from_messages(
    [
        ("system", sys_instruction + "\n\n{format}"),
        ("human", "Natural_language: {natural_language}"),
    ]
).partial(
    format=parser.get_format_instructions(),
)

In [85]:
# Generation pass: prompt -> chat model -> parsed SQLSchema.
chain1 = prompt1 | model | parser

In [86]:
# Run chain1 on the incoming dict and attach its result under "sql_output",
# keeping the original input keys so later steps can still read them.
carry_chain=RunnablePassthrough.assign(sql_output=chain1)

In [87]:
def map_to_prompt2(data):
    """Flatten the first-pass result into a flat dict of prompt variables.

    Args:
        data: Mapping with a "natural_language" entry and a "sql_output" entry
            whose value exposes `sql`, `confidence_score` and `feedback`
            attributes.

    Returns:
        A new dict with the keys "natural_language", "sql",
        "confidence_score" and "feedback", in that order.
    """
    first_pass = data["sql_output"]
    flattened = {"natural_language": data["natural_language"]}
    # Copy the first-pass attributes across under the same names.
    for attr_name in ("sql", "confidence_score", "feedback"):
        flattened[attr_name] = getattr(first_pass, attr_name)
    return flattened

# Wrap the plain function so it can be composed with `|` inside a chain.
mapper = RunnableLambda(map_to_prompt2)


In [88]:
# System rules for the second (refinement) pass. This rebinds the
# `sys_instruction` name that the first-pass prompt cell also assigns.
sys_instruction = """
You are a SQL Query Expert. Given a natural language,schema,previous sql query and a feedback you will write a optimized sql code to query information from the schema.
---------RULES-----------------
1. You must return only JSON object
2. You must only return SELECT commands
3. Return sql code,confidence score,feedback
4. Do not include any introductory text, markdown explanations, or extra commentary.
"""

# Refinement prompt: same layout as the first pass, but the human turn also
# carries the previous SQL and its feedback. `{format}` is pre-filled through
# partial(); the remaining variables are `natural_language`, `sql` and `feedback`.
prompt2 = ChatPromptTemplate.from_messages(
    [
        ("system", sys_instruction + "\n\n{format}"),
        ("human", "Natural_language: {natural_language}; SQL: {sql}; Feedback: {feedback}"),
    ]
).partial(
    format=parser.get_format_instructions(),
)

In [89]:
# Refinement pass: prompt -> chat model -> parsed SQLSchema.
chain2 = prompt2 | model | parser

In [90]:
def need_refinement(output: "SQLSchema", threshold: float = 0.6) -> bool:
    """Decide whether a generated query should go through the refinement pass.

    Args:
        output: Parsed model answer; only its `confidence_score` attribute is read.
        threshold: Scores strictly below this value trigger refinement.
            Defaults to 0.6.

    Returns:
        True when `output.confidence_score` is below `threshold`, else False.
    """
    # Read the score from the parsed result that was passed in. The chain
    # object (`chain1`) is a runnable and has no `confidence_score` attribute.
    return output.confidence_score < threshold

In [95]:
# End-to-end pipeline: generate SQL once, then refine it only when the model's
# own confidence is low. Either branch yields a SQLSchema, so `result.sql`
# works regardless of which path was taken.
final_chain = carry_chain | RunnableBranch(
    # Low-confidence first pass (below 0.6): send the SQL and its feedback
    # back to the model for a second attempt.
    (lambda data: data["sql_output"].confidence_score < 0.6, mapper | chain2),
    # Otherwise skip the extra LLM call and return the first-pass result as-is.
    RunnableLambda(lambda data: data["sql_output"]),
)

In [96]:
# Run the whole pipeline on one example request (this performs live LLM calls).
result = final_chain.invoke(
    {
        "natural_language": "From a school database consisting of name,age,grades,gender find the highest marks",
    }
)

In [None]:
# Show the SQL string from the pipeline's final parsed answer.
result.sql

'SELECT MAX(grades) FROM school_database'