Creating connection to the model and the DB

In [None]:
from langchain_ollama import OllamaLLM

# Local Ollama model used for every chain in this notebook.
# temperature=0 asks the model for its most likely completion so that repeated
# runs of the SQL-generation cells are as reproducible as the backend allows.
llm = OllamaLLM(model="llama3.2:latest", temperature=0)

Reading the configuration file to extract the values for the DB connection

In [4]:
import yaml

# The path is relative to the notebook's working directory.
with open('../config/config.yml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Fail fast with a readable message instead of an AttributeError on None
# (missing section) or a connection string that contains the text "None"
# (missing key).
mysql_config = config.get('mysql') if isinstance(config, dict) else None
if not isinstance(mysql_config, dict):
    raise ValueError("config.yml must contain a 'mysql' section with the connection settings")

required_keys = ('username', 'password', 'host', 'port', 'database')
missing_keys = [key for key in required_keys if key not in mysql_config]
if missing_keys:
    raise ValueError(f"'mysql' section of config.yml is missing: {', '.join(missing_keys)}")

username = mysql_config['username']
password = mysql_config['password']
host = mysql_config['host']
port = mysql_config['port']
database = mysql_config['database']

Establishment of the MySQL connection

In [8]:
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

# Assemble the URL from its components rather than interpolating them into an
# f-string: URL.create escapes characters such as '@', ':' or '/' in the
# credentials, which would otherwise corrupt the connection string.
connection_url = URL.create(
    drivername="mysql+mysqlconnector",
    username=username,
    password=password,
    host=host,
    port=port,
    database=database,
)

# Kept as a plain string for any later cell that still expects one.
connection_string = connection_url.render_as_string(hide_password=False)
engine = create_engine(connection_url)

Wrapping the engine in a LangChain SQLDatabase so the model can work with the DB

In [9]:
from langchain_community.utilities import SQLDatabase

# LangChain wrapper around the SQLAlchemy engine; the chains below use it both
# to build the SQL prompt and to execute queries.
db = SQLDatabase(engine=engine)

Checking the connection by running a simple query

In [14]:
# Smoke test: read the employee table to confirm the connection works.
# NOTE(review): this displays every row, including e-mail addresses and
# salaries - consider clearing the output before sharing the notebook.
db.run(command="SELECT * FROM employee")

"[(1, 'John', 'Doe', 'john.doe@example.com', Decimal('55000.00')), (2, 'Jane', 'Smith', 'jane.smith@example.com', Decimal('62000.00')), (3, 'Mike', 'Johnson', 'mike.johnson@example.com', Decimal('48000.00')), (4, 'Emily', 'Davis', 'emily.davis@example.com', Decimal('51000.00'))]"

Adding the chain that turns a natural-language question into a SQL query for the DB

In [28]:
from langchain.chains import create_sql_query_chain

# Chain that asks the model to write a SQL query for a question, using the
# database wrapper to describe the schema.
sql_chain = create_sql_query_chain(llm=llm, db=db)

In [29]:
# The question carries extra instructions for the model in addition to the
# actual request.
question = (
    "how many employees are there? "
    "You MUST RETURN ONLY MYSQL QUERIES. "
    "keep in mind, that table is called employee"
)

# Raw model answer; as the output below shows, it may contain more than the
# bare query.
response = sql_chain.invoke(dict(question=question))
print(response)

Question: how many employees are there?

SQLQuery:
```sql
SELECT COUNT(*) 
FROM `employee`;
```


Refining the model's response with a second chain so that only the bare SQL query remains

In [30]:
from langchain_core.prompts import (SystemMessagePromptTemplate, 
                                    HumanMessagePromptTemplate,
                                    ChatPromptTemplate)

from langchain_core.output_parsers import StrOutputParser

# System turn: fixes the assistant's role for every call.
system = SystemMessagePromptTemplate.from_template(
    """You are helpful AI assistant who answer user question based on the provided context."""
)

# Human turn: carries the context and the question as template variables.
prompt = HumanMessagePromptTemplate.from_template(
    """Answer user question based on the provided context ONLY! If you do not know the answer, just say "I don't know".
            ### Context:
            {context}

            ### Question:
            {question}

            ### Answer:"""
)

messages = [system, prompt]
template = ChatPromptTemplate(messages)

# Prompt -> model -> plain string.
qna_chain = template | llm | StrOutputParser()


def ask_llm(context, question):
    """Run the Q&A chain for one context/question pair and return its output.

    Args:
        context: Text substituted for ``{context}`` in the prompt.
        question: Text substituted for ``{question}`` in the prompt.

    Returns:
        Whatever ``qna_chain.invoke`` returns for these inputs.
    """
    chain_input = {'context': context, 'question': question}
    return qna_chain.invoke(chain_input)

In [31]:
from langchain_core.runnables import chain

def _strip_sql_fences(text):
    """Return ``text`` without a surrounding Markdown code fence, trimmed."""
    import re

    fenced = re.search(r"```(?:sql|mysql)?\s*(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    return (fenced.group(1) if fenced else text).strip()


@chain
def get_correct_sql_query(input):
    """Ask the model to reduce a previous answer to a single SQL query.

    Args:
        input: Mapping with two keys:
            'context'  - the text that contains the query (e.g. the raw
                         output of the SQL-generation chain);
            'question' - the question the query has to answer.

    Returns:
        The model's answer with any Markdown code fence and surrounding
        whitespace removed.
    """
    context = input['context']
    question = input['question']

    instruction = """
        Use above context to fetch the correct SQL query for following question
        {}

        Do not enclose query in ```sql and do not write preamble and explanation.
        You MUST return only single SQL query.
    """.format(question)

    response = ask_llm(context=context, question=instruction)

    # The prompt forbids code fences, but the model does not always comply;
    # strip them so the result can be executed as-is.
    return _strip_sql_fences(response)


In [32]:
# Reduce the raw model answer to a bare SQL query.
# NOTE(review): this rebinds `response`, so running the cell a second time
# passes the already-cleaned query back in as context.
response = get_correct_sql_query.invoke(dict(context=response, question=question))

In [33]:
# `response` is text produced by the model, so treat it as untrusted input:
# only a single SELECT statement is allowed to reach the database.
sql_to_check = response.strip().rstrip(';').strip()
if not sql_to_check.lower().startswith('select') or ';' in sql_to_check:
    raise ValueError(f"Refusing to run non-SELECT or multi-statement SQL: {response!r}")

db.run(response)

'[(4,)]'