In [25]:

from langchain.llms import OpenAI
from langchain.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.chat_models import ChatOllama
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI


def connectDatabase(username, port, host, password, database):
    mysql_uri = f"mysql+mysqlconnector://{username}:{password}@{host}:{port}/{database}"
    db = SQLDatabase.from_uri(mysql_uri)
    return db


def runQuery(query,db):
    return db.run(query) if db else "Please connect to database"


def getDatabaseSchema(db):
    return db.get_table_info() if db else "Please connect to database"


llm = ChatGoogleGenerativeAI(model="gemini-pro",api_key="AIzaSyDKAeomvp2rp8ICJ7IF0z8rTcZkDih8mog")




In [26]:
def getResponseForQueryResult(question, query, result, db):
    template2 = """below is the schema of MYSQL database, read the schema carefully about the table and column names of each table.
    Also look into the conversation if available
    Finally write a response in natural language by looking into the conversation and result.

    {schema}

    Here are some example for you:
    question: how many albums we have in database
    SQL query: SELECT COUNT(*) FROM album;
    Result : [(34,)]
    Response: There are 34 albums in the database.

    question: how many users we have in database
    SQL query: SELECT COUNT(*) FROM customer;
    Result : [(59,)]
    Response: There are 59 amazing users in the database.

    question: how many users above are from india we have in database
    SQL query: SELECT COUNT(*) FROM customer WHERE country=india;
    Result : [(4,)]
    Response: There are 4 amazing users in the database.

    your turn to write response in natural language from the given result :
    question: {question}
    SQL query : {query}
    Result : {result}
    Response:
    """

    prompt2 = ChatPromptTemplate.from_template(template2)
    chain2 = prompt2 | llm

    response = chain2.invoke({
        "question": question,
        "schema": getDatabaseSchema(db),
        "query": query,
        "result": result
    })

    return response.content

In [102]:
def getQueryFromLLM(question,db):
    template = """below is the schema of MYSQL database, read the schema carefully about the table and column names. Also take care of table or column name case sensitivity.
    Finally answer user's question in the form of SQL query.

    {schema}

    please only provide the SQL query and nothing else

    for example:
    question: how many albums we have in database
    SQL query: SELECT COUNT(*) FROM album
    question: how many customers are from Brazil in the database ?
    SQL query: SELECT COUNT(*) FROM customer WHERE country=Brazil

    your turn :
    question: {question}
    SQL query :
    please only provide the SQL query and nothing else
    """

    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | llm

    response = chain.invoke({
        "question": question,
        "schema": getDatabaseSchema(db)
    })
    return response.content
    # return response

In [100]:
def validate_query(input,question,db):
    template = """I am giving you output of one model, actully i expect the output as sql query, but sometime model giving extra quotes or unexpected words in the ouput,
    so you please currect the output, give exact query.below giving schema of database

    {schema}

    and here i am giving the curresponding question of the query,

    {question}

    please only provide the SQL query and nothing else

    for example:
    input: SELECT COUNT(*) FROM t_shirts
    SQL query: SELECT COUNT(*) FROM t_shirts;

    input: SQL query :\nSELECT COUNT(*) FROM t_shirts
    SQL query: SELECT COUNT(*) FROM t_shirts;

    input: ```sql\nSELECT COUNT(*) FROM t_shirts;\n```
    SQL query: SELECT COUNT(*) FROM t_shirts;

    input: SQL query :\n```sql\nSELECT COUNT(*) FROM t_shirts;\n```
    SQL query: SELECT COUNT(*) FROM t_shirts;

    
    your turn :
    input: {input}
    SQL query: 
    please only provide the SQL query and nothing else
    """

    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | llm

    response = chain.invoke({
        "input": input,
        "question":question,
        "schema": getDatabaseSchema(db)
    })
    # return response.content
    return response

In [None]:
def retry(question,db):
    try:
        query = getQueryFromLLM(question,db)
        print(query,'query')

        # validate_query
        # query = validate_query(query,question,db)
        # print(query,'query2')

        result = runQuery(query, db)
        print(result)
        return query,result
        
    except:
        return retry(question,db)

In [120]:
db=connectDatabase(username='root', port='3306', host='localhost', password='Atk%408522', database='atliq_tshirts')

question='give me t-shirt and brand which have colour black'

query,result=retry(question,db)

response = getResponseForQueryResult(question, query, result, db)

print (response)

```sql
SELECT
  brand,
  color
FROM t_shirts
WHERE
  color = "Black";
``` query
```sql
SELECT
  t_shirts.brand,
  t_shirts.color
FROM t_shirts
WHERE
  t_shirts.color = 'Black';
``` query
```sql
SELECT
  T1.t_shirt_id,
  T1.brand
FROM T_SHIRTS AS T1
JOIN DISCOUNTS AS T2
  ON T1.T_SHIRT_ID = T2.T_SHIRT_ID
WHERE
  T1.color = 'Black';
``` query
```sql
SELECT
  t_shirts.brand,
  t_shirts.color
FROM t_shirts
WHERE
  t_shirts.color = 'Black';
``` query
SELECT
    `t`.`brand`,
    `t`.`color`
FROM
    `t_shirts` AS `t`
WHERE
    `t`.`color` = 'Black'; query
There are 14 t-shirts with color black, which are from 5 different brands. The brands are Van Huesen, Levi, Nike and Adidas.
