Loading Ollama model

In [1]:
from langchain_ollama import ChatOllama
# Ollama-backed chat model shared by the chains and the agent defined later in this notebook.
llm = ChatOllama(model="llama3.1:latest")

Establishing connection to the Database

In [2]:
from urllib.parse import quote

import yaml
from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase

# Connection settings live in the project config so credentials stay out of the notebook.
with open('../config/config.yml', 'r') as file:
    config = yaml.safe_load(file)

# Fail with a readable message instead of an AttributeError on None further down.
mysql_config = (config or {}).get('mysql')
if not mysql_config:
    raise KeyError("Missing 'mysql' section in ../config/config.yml")

username = mysql_config.get('username')
password = mysql_config.get('password')
host = mysql_config.get('host')
port = mysql_config.get('port')
database = mysql_config.get('database')

# Percent-encode the credentials: characters such as '@', ':' or '/' in a
# username or password would otherwise be parsed as URL delimiters.
# safe="" makes quote() encode '/' as well.
quoted_username = quote(str(username), safe="")
quoted_password = quote(str(password), safe="")

connection_string = (
    f"mysql+mysqlconnector://{quoted_username}:{quoted_password}@{host}:{port}/{database}"
)
engine = create_engine(connection_string)

# LangChain wrapper around the engine; used by the chains, tools and agent below.
db = SQLDatabase(engine)

Checking the connection to the database

In [3]:
# Connectivity check only: cap the row count so the whole table is never dumped into the output.
db.run("SELECT * FROM employee LIMIT 5")

"[(1, 'John', 'Doe', 'john.doe@example.com', Decimal('55000.00')), (2, 'Jane', 'Smith', 'jane.smith@example.com', Decimal('62000.00')), (3, 'Mike', 'Johnson', 'mike.johnson@example.com', Decimal('48000.00')), (4, 'Emily', 'Davis', 'emily.davis@example.com', Decimal('51000.00'))]"

In [4]:
from langchain.chains import create_sql_query_chain

# Chain built from the model and the database wrapper; it is invoked with a natural-language question.
sql_chain = create_sql_query_chain(llm, db)
# The trailing sentence asks the model to answer with a MySQL query only.
question = "what is the average age for those who were diagnosed with diabete? You MUST RETURN ONLY MYSQL QUERIES."
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, yet a later
# cell reads `response` as its context — confirm it is actually set on a fresh run.
response = sql_chain.invoke({"question": question})

KeyboardInterrupt: 

Telling the model how to extract the SQL query from the previous response (the response is free text, so the query has to be parsed out of it)

In [5]:
from langchain_core.prompts import (SystemMessagePromptTemplate, 
                                    HumanMessagePromptTemplate,
                                    ChatPromptTemplate)
from langchain_core.output_parsers import StrOutputParser

# System turn: fixes the assistant's role for every request.
system = SystemMessagePromptTemplate.from_template(
    """You are helpful AI assistant who answer user question based on the provided context."""
)

# Human turn: {context} and {question} are filled in when the chain is invoked.
prompt = HumanMessagePromptTemplate.from_template(
    """Answer user question based on the provided context ONLY! If you do not know the answer, just say "I don't know".
            ### Context:
            {context}

            ### Question:
            {question}

            ### Answer:"""
)

# System message first, then the human message.
messages = [system, prompt]
template = ChatPromptTemplate(messages)

# Prompt -> model -> plain string.
qna_chain = template | llm | StrOutputParser()

def ask_llm(context, question):
    """Run the question-answering chain and return the model's reply.

    Args:
        context: Text substituted for the ``{context}`` placeholder of the prompt.
        question: Text substituted for the ``{question}`` placeholder of the prompt.

    Returns:
        Whatever ``qna_chain`` produces for these inputs (the chain ends in
        a ``StrOutputParser``).
    """
    chain_inputs = dict(context=context, question=question)
    answer = qna_chain.invoke(chain_inputs)
    return answer

In [47]:
from langchain_core.runnables import chain

def _strip_sql_wrapping(text):
    """Return ``text`` without a Markdown code fence or a leading ``SQLQuery:`` label.

    Non-string values are returned unchanged.
    """
    import re  # local import keeps this cell runnable on its own

    if not isinstance(text, str):
        return text

    cleaned = text.strip()
    # If the reply contains a fenced block, keep only its contents. The optional
    # group swallows a language tag such as "sql" when it sits on the fence line.
    fenced = re.search(r"```(?:[A-Za-z]+[ \t]*\n)?(.*?)```", cleaned, flags=re.DOTALL)
    if fenced:
        cleaned = fenced.group(1).strip()
    # Drop a leading "SQLQuery:" label if the model echoed one.
    cleaned = re.sub(r"^SQLQuery:\s*", "", cleaned, flags=re.IGNORECASE)
    return cleaned.strip()


@chain
def get_correct_sql_query(input):
    """Ask the LLM to reduce a verbose answer to a single SQL query.

    Args:
        input: Mapping with two keys: 'context' (the earlier model response that
            contains the query) and 'question' (the user's question).

    Returns:
        The model's reply with any Markdown code fence or ``SQLQuery:`` label
        removed, so it can be passed to the database as-is.
    """
    context = input['context']
    question = input['question']

    instruction = """
        Use above context to fetch the correct SQL query for following question
        {}

        Do not enclose query in ```sql and do not write preamble and explanation.
        You MUST return only single SQL query.
    """.format(question)

    response = ask_llm(context=context, question=instruction)

    # The model does not always follow the "no ```sql" instruction, and the reply
    # is executed downstream, so normalise it before returning.
    return _strip_sql_wrapping(response)

In [48]:
# Store the extracted query under its own name instead of overwriting `response`:
# re-running this cell then still uses the raw chain output as context rather than
# feeding its own previous result back in.
corrected_query = get_correct_sql_query.invoke({'context': response, 'question': question})
db.run(corrected_query)

"[(Decimal('37.0672'),)]"

Building the chain that is capable of extracting the right query

In [38]:
from operator import itemgetter

from langchain_community.tools import QuerySQLDatabaseTool
from langchain_core.runnables import RunnablePassthrough

# Tool that executes a SQL string against the database.
execute_query = QuerySQLDatabaseTool(db=db)
sql_query = create_sql_query_chain(llm, db)

# The chain is invoked with {'question': ...}. Both branches of the mapping receive
# that whole dict, so the question text must be picked out explicitly. A bare
# RunnablePassthrough() would forward the dict itself, and the prompt would then
# contain "{'question': '...'}" instead of the question.
final_chain = (
    {'context': sql_query, 'question': itemgetter('question')}
    | get_correct_sql_query
    | execute_query
)

In [39]:
# The trailing sentence asks the model to answer with a MySQL query only.
question = "how many employees are there? You MUST RETURN ONLY MYSQL QUERIES."

# Runs the whole pipeline: SQL generation, query extraction, then execution against the database.
response = final_chain.invoke({'question': question})
print(response)

[(4,)]


Building an agent

Loading the schema of the current database

In [6]:
from sqlalchemy import inspect

def get_db_schema():
    """Collect the column names of every table visible through ``engine``.

    Returns:
        A dict mapping each table name to the list of its column names, in the
        order reported by the SQLAlchemy inspector.
    """
    db_inspector = inspect(engine)
    return {
        table: [column["name"] for column in db_inspector.get_columns(table)]
        for table in db_inspector.get_table_names()
    }

In [7]:
def generate_system_message():
    """Compose the system prompt that describes the database layout.

    Returns:
        A prompt string with one ``- table: col1, col2`` line per table
        returned by ``get_db_schema``, followed by an instruction to use the
        table names exactly as listed.
    """
    table_lines = []
    for table, columns in get_db_schema().items():
        column_list = ', '.join(columns)
        table_lines.append(f"- {table}: {column_list}")
    tables_info = "\n".join(table_lines)

    return f"""
    You are an expert in querying SQL databases. The database contains the following tables and columns:
    {tables_info}

    Always use the exact table names provided above. Do not assume or pluralize table names.
    """

Creation of the agent from scratch

Create a tool node with a fallback that passes the database error message back to the model so it can fix the SQL query

In [8]:
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
from typing import Any

def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Wrap the given tools in a ToolNode that has an error-handling fallback.

    `handle_tool_error` is registered as the only fallback, and the exception
    is passed to it under the "error" key.
    """
    tool_node = ToolNode(tools)
    error_fallback = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_fallback], exception_key="error")


def handle_tool_error(state) -> dict:
    """Turn a failed tool invocation into ToolMessages for the agent.

    Args:
        state: Mapping with the message list under "messages" and, optionally,
            the raised exception under "error".

    Returns:
        A dict whose "messages" entry holds one ToolMessage per tool call of
        the last message, each carrying the error text and that call's id.
    """
    error = state.get("error")
    last_message = state["messages"][-1]

    replies = []
    for tool_call in last_message.tool_calls:
        reply = ToolMessage(
            content=f"Error: {repr(error)}\n please fix your mistakes.",
            tool_call_id=tool_call["id"],
        )
        replies.append(reply)
    return {"messages": replies}

Extraction of the tools available for my model and DB

In [9]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Toolkit built from the database wrapper and the model; get_tools() returns its tool objects.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

# Pick individual tools by name; next() raises StopIteration if no tool has that name.
list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")

# Quick check: list the tables the tool can see.
list_tables_tool.invoke("")

'cancer, diabetes, employee'

In [10]:
# Show what the schema tool returns for the `employee` table.
print(get_schema_tool.invoke("employee"))


CREATE TABLE employee (
	id INTEGER NOT NULL AUTO_INCREMENT, 
	first_name VARCHAR(50) NOT NULL, 
	last_name VARCHAR(50) NOT NULL, 
	email VARCHAR(100) NOT NULL, 
	salary DECIMAL(10, 2) NOT NULL, 
	PRIMARY KEY (id)
)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4

/*
3 rows from employee table:
id	first_name	last_name	email	salary
1	John	Doe	john.doe@example.com	55000.00
2	Jane	Smith	jane.smith@example.com	62000.00
3	Mike	Johnson	mike.johnson@example.com	48000.00
*/


Usage of the prebuilt agent

In [30]:
from langgraph.prebuilt import create_react_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# NOTE(review): `toolkit` and `tools` are rebuilt here the same way as in the earlier
# toolkit cell — presumably so this cell can run on its own; confirm.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()

# System prompt listing the tables and columns found in the database.
system_message = generate_system_message()

# NOTE(review): newer langgraph releases appear to replace `state_modifier` with
# `prompt` — check against the installed version before upgrading.
agent_executor = create_react_agent(
    llm,
    tools,
    state_modifier=system_message,
    debug=False
)

In [31]:
from langchain_core.messages import HumanMessage

question = "what is the average age for those who were diagnosed with diabete?"

# The question is passed to the agent as a single human message.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, so no agent output is shown.
agent_executor.invoke({"messages": [HumanMessage(content=question)]})

KeyboardInterrupt: 