Loading Ollama model

In [1]:
from langchain_ollama import ChatOllama
# temperature=0 makes generation as repeatable as the model allows, so re-running the
# notebook produces the same SQL instead of a fresh random sample on every run.
llm = ChatOllama(model="llama3.2:latest", temperature=0)

Establishing a connection to the database

In [2]:
import yaml
from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase

from sqlalchemy.engine import URL

# Connection settings live in a config file outside the notebook, so no credentials
# are hardcoded here.
with open('../config/config.yml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# safe_load returns None for an empty file and .get('mysql') returns None when the
# section is absent; fail with a clear message instead of an AttributeError below.
mysql_config = (config or {}).get('mysql')
if not mysql_config:
    raise ValueError("Missing 'mysql' section in ../config/config.yml")

username = mysql_config.get('username')
password = mysql_config.get('password')
host = mysql_config.get('host')
port = mysql_config.get('port')
database = mysql_config.get('database')

# Build the URL from its parts instead of an f-string: URL.create escapes characters
# such as '@', ':' or '/' in the credentials, which would otherwise corrupt the URL.
connection_url = URL.create(
    drivername="mysql+mysqlconnector",
    username=username,
    password=password,
    host=host,
    port=port,
    database=database,
)
# String form kept under the original name for any cell that still reads it.
# It contains the password, so never display or log it.
connection_string = connection_url.render_as_string(hide_password=False)
engine = create_engine(connection_url)

db = SQLDatabase(engine)

Checking the connection to the database

In [3]:
# Connectivity smoke test: a bounded sample is enough to prove the connection works and
# avoids pulling and displaying the whole table, which holds personal data (e-mail addresses).
db.run("SELECT * FROM employee LIMIT 5")

"[(1, 'John', 'Doe', 'john.doe@example.com', Decimal('55000.00')), (2, 'Jane', 'Smith', 'jane.smith@example.com', Decimal('62000.00')), (3, 'Mike', 'Johnson', 'mike.johnson@example.com', Decimal('48000.00')), (4, 'Emily', 'Davis', 'emily.davis@example.com', Decimal('51000.00'))]"

In [15]:
from langchain.chains import create_sql_query_chain

# The question carries an explicit instruction because the raw model reply is
# reused as context by a later cell.
question = "how many employees are there? You MUST RETURN ONLY MYSQL QUERIES."

# Chain that turns a natural-language question into SQL for the connected database.
sql_chain = create_sql_query_chain(llm, db)
response = sql_chain.invoke({"question": question})

Telling the model how to extract the query from the previous response (the raw response string needs to be parsed)

In [16]:
from langchain_core.prompts import (SystemMessagePromptTemplate, 
                                    HumanMessagePromptTemplate,
                                    ChatPromptTemplate)
from langchain_core.output_parsers import StrOutputParser

# System turn: sets the assistant's role for every call.
system = SystemMessagePromptTemplate.from_template(
    """You are helpful AI assistant who answer user question based on the provided context."""
)

# Human turn: {context} and {question} are filled in when the chain is invoked.
human_template = """Answer user question based on the provided context ONLY! If you do not know the answer, just say "I don't know".
            ### Context:
            {context}

            ### Question:
            {question}

            ### Answer:"""
prompt = HumanMessagePromptTemplate.from_template(human_template)

messages = [system, prompt]
template = ChatPromptTemplate(messages)

# prompt template -> model -> plain string
qna_chain = template | llm | StrOutputParser()

def ask_llm(context, question):
    """Invoke ``qna_chain`` with the given context and question and return its result."""
    payload = {'context': context, 'question': question}
    answer = qna_chain.invoke(payload)
    return answer

In [17]:
from langchain_core.runnables import chain

import re

# Matches a reply wrapped entirely in a Markdown code fence, optionally tagged sql/mysql.
_SQL_FENCE = re.compile(r"^```(?:mysql|sql)?\s*(.*?)\s*```$", re.IGNORECASE | re.DOTALL)


def _strip_sql_fence(text):
    """Return *text* without surrounding whitespace or a wrapping ``` / ```sql fence."""
    cleaned = text.strip()
    fenced = _SQL_FENCE.match(cleaned)
    return fenced.group(1) if fenced else cleaned


@chain
def get_correct_sql_query(input):
    """Ask the model to extract a single SQL query from a previous model reply.

    Args:
        input: mapping with two keys:
            'context'  - earlier model output that contains the SQL query;
            'question' - the user's question the query should answer.

    Returns:
        The model's reply with surrounding whitespace and any wrapping Markdown
        code fence removed, so it can be handed to the database as-is.
    """
    context = input['context']
    question = input['question']

    instruction = """
        Use above context to fetch the correct SQL query for following question
        {}

        Do not enclose query in ```sql and do not write preamble and explanation.
        You MUST return only single SQL query.
    """.format(question)

    response = ask_llm(context=context, question=instruction)

    # Models often add a code fence even when told not to; a fenced reply would
    # be rejected by the database, so strip it here.
    return _strip_sql_fence(response)

In [18]:
# Store the extracted query under its own name. Overwriting `response` here would make
# the cell non-idempotent: a second run would feed the extracted query back in as context.
corrected_query = get_correct_sql_query.invoke({'context': response, 'question': question})
db.run(corrected_query)

'[(4,)]'

Building the chain that is capable of extracting the right query

In [19]:
from langchain_community.tools import QuerySQLDatabaseTool
from langchain_core.runnables import RunnablePassthrough 

from operator import itemgetter

# Tool that runs a SQL string against the database and returns the result.
execute_query = QuerySQLDatabaseTool(db=db)
sql_query = create_sql_query_chain(llm, db)

# The chain is invoked with {'question': ...}. itemgetter pulls out the question
# string; RunnablePassthrough() would forward the whole input dict, so the
# extraction prompt would receive "{'question': '...'}" instead of the question text.
final_chain = (
    {'context': sql_query, 'question': itemgetter('question')}
    | get_correct_sql_query
    | execute_query
)

In [24]:
# End-to-end run: generate SQL, extract the query, execute it.
question = "how many employees are there? You MUST RETURN ONLY MYSQL QUERIES."
chain_input = {'question': question}

response = final_chain.invoke(chain_input)
print(response)

[(4,)]


Building an agent

Loading the schema of the current database

In [25]:
from sqlalchemy import inspect

def get_db_schema():
    """Return a dict mapping each table name in ``engine`` to its list of column names."""
    db_inspector = inspect(engine)
    return {
        table: [column["name"] for column in db_inspector.get_columns(table)]
        for table in db_inspector.get_table_names()
    }

In [26]:
def generate_system_message():
    """Build a system prompt that lists every table with its columns.

    The table listing comes from ``get_db_schema()``: one ``- table: col, col`` line per table.
    """
    table_lines = []
    for table, columns in get_db_schema().items():
        table_lines.append(f"- {table}: {', '.join(columns)}")
    tables_info = "\n".join(table_lines)

    return f"""
    You are an expert in querying SQL databases. The database contains the following tables and columns:
    {tables_info}

    Always use the exact table names provided above. Do not assume or pluralize table names.
    """

In [None]:
from langgraph.prebuilt import create_react_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# The toolkit exposes the database to the agent as a set of tools.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

# System prompt listing the real table and column names.
system_message = generate_system_message()

# `prompt` is the current name for the system-prompt argument of create_react_agent;
# the older `state_modifier` keyword is deprecated.
# NOTE(review): requires a langgraph release that already has `prompt` - confirm the installed version.
agent_executor = create_react_agent(
    llm,
    tools,
    prompt=system_message,
    debug=False
)

In [28]:
from langchain_core.messages import HumanMessage

# Plain natural-language question with no SQL instructions added.
question = "How many employees are there?"
user_message = HumanMessage(content=question)

agent_executor.invoke({"messages": [user_message]})

{'messages': [HumanMessage(content='How many employees are there?', additional_kwargs={}, response_metadata={}, id='82dbe260-b8ac-4d42-9382-a0ab8a7eec5a'),
  AIMessage(content='', additional_kwargs={}, response_metadata={'model': 'llama3.2:latest', 'created_at': '2025-02-16T18:05:40.536513Z', 'done': True, 'done_reason': 'stop', 'total_duration': 18233652542, 'load_duration': 1357232125, 'prompt_eval_count': 629, 'prompt_eval_duration': 11234000000, 'eval_count': 60, 'eval_duration': 5109000000, 'message': Message(role='assistant', content='', images=None, tool_calls=None)}, id='run-fd64fc35-43a6-45df-b76c-12a85ccb4b74-0', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': '1bcc4d75-ec24-48d9-93a8-6afedd60f0da', 'type': 'tool_call'}, {'name': 'sql_db_query_checker', 'args': {'query': 'SELECT COUNT(*) FROM employee'}, 'id': '6392db48-268e-4cc1-bbd5-85ddcb6b3484', 'type': 'tool_call'}, {'name': 'sql_db_query', 'args': {'query': 'SELECT COUNT(*) FROM employee'}, 'id': 'd5990a32-