Creating connection to the model and the DB

In [27]:
from langchain_ollama import ChatOllama

# Initialize the chat model client for the "llama3.2:latest" Ollama model.
# No endpoint is passed here, so the client uses its default one —
# presumably a locally running Ollama instance; confirm for your setup.
llm = ChatOllama(model="llama3.2:latest")

Reading the configuration file to extract the values for the DB connection

In [28]:
import yaml

# Parse the YAML configuration (safe_load: plain data only, no object construction).
with open('../config/config.yml', 'r') as file:
    config = yaml.safe_load(file)

# Connection settings live under the top-level 'mysql' key.
mysql_config = config.get('mysql')

# Pull the five connection fields in one pass; a field that is absent
# from the section comes back as None.
username, password, host, port, database = (
    mysql_config.get(key)
    for key in ('username', 'password', 'host', 'port', 'database')
)

Establishment of the MySQL connection

In [29]:
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

# Build the URL from its components instead of interpolating them into an
# f-string: URL.create escapes each part, so a username or password that
# contains characters such as "@", ":" or "/" no longer corrupts the URL.
connection_url = URL.create(
    drivername="mysql+mysqlconnector",
    username=username,
    password=password,
    host=host,
    port=port,
    database=database,
)

# Kept as a plain string for any code that still reads `connection_string`.
# It embeds the real password, so do not print or log it.
connection_string = connection_url.render_as_string(hide_password=False)

engine = create_engine(connection_url)

Creation of the DB connection to work with the Model

In [30]:
from langchain_community.utilities import SQLDatabase

# Wrap the SQLAlchemy engine in LangChain's SQLDatabase utility; this `db`
# object is what the query chains, tools and toolkit below are built on.
db = SQLDatabase(engine)

Checking the connection by running a simple query

In [31]:
# Smoke test of the connection: the result comes back as one string holding
# the row tuples.
# NOTE(review): this writes every employee row (names, e-mail addresses,
# salaries) into the saved notebook output — consider a COUNT(*) or LIMIT
# query, or clear the output, before sharing.
db.run("SELECT * FROM employee")

"[(1, 'John', 'Doe', 'john.doe@example.com', Decimal('55000.00')), (2, 'Jane', 'Smith', 'jane.smith@example.com', Decimal('62000.00')), (3, 'Mike', 'Johnson', 'mike.johnson@example.com', Decimal('48000.00')), (4, 'Emily', 'Davis', 'emily.davis@example.com', Decimal('51000.00'))]"

Adding the chain that generates the actual SQL query for the DB

In [32]:
from langchain.chains import create_sql_query_chain

# Query-writing chain built from the LLM and the DB wrapper; it is invoked
# below with a {'question': ...} dict.
sql_chain = create_sql_query_chain(llm, db)

In [33]:
# The question itself carries the formatting instructions for the model.
question = "how many employees are there? You MUST RETURN ONLY MYSQL QUERIES. Dont forget to check the tables available"
# `response` holds the raw chain output; it is reused further down as the
# context for the query clean-up step.
response = sql_chain.invoke({'question': question})
print(response)

Answer: There are 3 employees.

SQLQuery: SELECT COUNT(*) FROM employee;


Upload of csv files to the DB

In [None]:
import os
import pandas as pd

# Names (not full paths) of the entries in ../data; the upload loop below
# iterates over this list.
file_dir_list = os.listdir("../data")
print(file_dir_list)

['cancer.csv', 'diabetes.csv']


In [14]:
# Read the files from the same directory that was listed into `file_dir_list`
# ("../data"), rather than from an absolute path that exists on one machine only.
data_dir = "../data"

for file in file_dir_list:
    full_file_path = os.path.join(data_dir, file)
    # The table name is the file name without its extension.
    file_name, file_extension = os.path.splitext(file)
    if file_extension == ".csv":
        df = pd.read_csv(full_file_path)
    elif file_extension == ".xlsx":
        df = pd.read_excel(full_file_path)
    else:
        # Any other entry in the directory aborts the whole upload.
        raise ValueError("The selected file type is not supported")
    # if_exists="replace": re-running the cell rewrites the table from the
    # file instead of raising because the table already exists.
    df.to_sql(file_name, engine, index=False, if_exists="replace")
print("All csv files are saved into the sql database.")

All csv files are saved into the sql database.


Checking whether the upload of the files succeeded

In [15]:
from sqlalchemy import inspect

# Ask the database, through the engine, which tables it currently has, to
# confirm that each uploaded file became a table.
insp = inspect(engine)
table_names = insp.get_table_names()
print("Available table names in created SQL DB:", table_names)

Available table names in created SQL DB: ['cancer', 'diabetes', 'employee']


Updating the query to make it more efficient for SQL generation

In [34]:
from langchain_core.prompts import (SystemMessagePromptTemplate, 
                                    HumanMessagePromptTemplate,
                                    ChatPromptTemplate)

from langchain_core.output_parsers import StrOutputParser

# System message: fixes the assistant's role for every call.
system = SystemMessagePromptTemplate.from_template("""You are helpful AI assistant who answer user question based on the provided context.""")

# Human message text with two placeholders, {context} and {question}, that
# are filled in at invoke time.
prompt = """Answer user question based on the provided context ONLY! If you do not know the answer, just say "I don't know".
            ### Context:
            {context}

            ### Question:
            {question}

            ### Answer:"""

# NOTE: `prompt` is rebound here from the raw string to the template object.
prompt = HumanMessagePromptTemplate.from_template(prompt)

messages = [system, prompt]
template = ChatPromptTemplate(messages)

# Pipeline: fill the template -> call the LLM -> parse the model output
# (StrOutputParser presumably yields the reply text as a plain str; the
# result is later passed to db.run as SQL text).
qna_chain = template | llm | StrOutputParser()

def ask_llm(context, question):
    """Run the question-answering chain for one context/question pair.

    Args:
        context: Value substituted for the ``{context}`` placeholder.
        question: Value substituted for the ``{question}`` placeholder.

    Returns:
        The output of ``qna_chain.invoke`` for these two values.
    """
    payload = {'context': context, 'question': question}
    answer = qna_chain.invoke(payload)
    return answer

In [35]:
from langchain_core.runnables import chain

@chain
def get_correct_sql_query(input):
    """Ask the LLM to extract a single SQL query from a draft answer.

    Args:
        input: Mapping with a 'context' entry (the text the query is taken
            from) and a 'question' entry (the user question).

    Returns:
        Whatever ``ask_llm`` returns for the assembled instruction.
    """
    instruction_template = """
        Use above context to fetch the correct SQL query for following question
        {}

        Do not enclose query in ```sql and do not write preamble and explanation.
        You MUST return only single SQL query.
    """

    # The draft text is sent as the context; the instruction (with the user
    # question embedded) takes the prompt's question slot.
    return ask_llm(
        context=input['context'],
        question=instruction_template.format(input['question']),
    )


In [36]:
# Feed the raw output of the first chain back in as context and keep only
# the cleaned-up SQL.
# NOTE(review): `response` is both the input and the output here, so running
# this cell a second time feeds the already-cleaned query back in instead of
# the original chain output — the cell is not idempotent.
response = get_correct_sql_query.invoke({'context': response, 'question': question})

In [37]:
# Execute the model-generated SQL text as-is against the database.
db.run(response)

'[(4,)]'

Below is the creation of the final SQL chain that outputs the human-like response. However, the model should decide whether to use this chain, as the model is also supposed to generate interactive images.

In [38]:
from operator import itemgetter

from langchain_community.tools import QuerySQLDatabaseTool
from langchain_core.runnables import RunnablePassthrough

# Tool that runs a SQL string against the DB and returns the result.
execute_query = QuerySQLDatabaseTool(db=db)
# Draft-query generator, same construction as `sql_chain` above.
sql_query = create_sql_query_chain(llm, db)

# The chain is invoked with {'question': "<text>"}.
#   context  -> draft answer produced by `sql_query` from that input dict
#   question -> the question text itself
# itemgetter('question') extracts the string. RunnablePassthrough() would
# forward the whole input dict, so the clean-up prompt would receive
# "{'question': '...'}" instead of the question.
final_chain = (
    {'context': sql_query, 'question': itemgetter('question')}
    | get_correct_sql_query
    | execute_query
)

In [39]:
# End-to-end run: draft query -> clean-up -> execution. What is printed is
# the database result, not the SQL text.
question = "how many employees are there? You MUST RETURN ONLY MYSQL QUERIES."

response = final_chain.invoke({'question': question})
print(response)

[(4,)]


Integration of the SQL agent, which generates requests based on the database schema. Additionally, it asks the model to regenerate the SQL query in case of failure; the model is provided with the feedback from the DB connection.

In [40]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Toolkit bound to the same DB wrapper and LLM used above; it supplies the
# tools handed to the agent.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [41]:
# Materialise the toolkit's tools; the bare `tools` expression displays the
# list as the cell output.
tools = toolkit.get_tools()
tools

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x135710410>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x135710410>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x135710410>),
 QuerySQLCheckerTool(description='Use this tool to double check if your 

In [42]:
from langchain_core.messages import SystemMessage

# Re-read the same config file that was loaded for the DB settings.
with open('../config/config.yml', 'r') as file:
    config = yaml.safe_load(file)

# The agent's system prompt text is stored under llm_model -> prefix.
# If the 'llm_model' section is absent, .get returns None and the chained
# .get raises AttributeError.
prefix = config.get('llm_model').get('prefix')

system_message = SystemMessage(content=prefix)

In [43]:
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent

# ReAct-style agent over the SQL tools, with the configured prefix as its
# system message.
# NOTE(review): the name of the `state_modifier` keyword may differ between
# langgraph releases — confirm against the installed version.
agent_executor = create_react_agent(llm, tools, state_modifier=system_message, debug=False)
question = "How many employees are there?"
# question = "How many departments are there?"

# The last expression displays the full message history of the run.
agent_executor.invoke({"messages": [HumanMessage(content=question)]})

{'messages': [HumanMessage(content='How many employees are there?', additional_kwargs={}, response_metadata={}, id='ecb70cd9-dabe-4b32-a68a-05ef4660b626'),
  AIMessage(content='', additional_kwargs={}, response_metadata={'model': 'llama3.2:latest', 'created_at': '2025-02-16T16:01:21.286044Z', 'done': True, 'done_reason': 'stop', 'total_duration': 9514554458, 'load_duration': 31684792, 'prompt_eval_count': 717, 'prompt_eval_duration': 8120000000, 'eval_count': 19, 'eval_duration': 1356000000, 'message': Message(role='assistant', content='', images=None, tool_calls=None)}, id='run-1e28269a-1ceb-4a4b-8c5b-79e5a921b1f1-0', tool_calls=[{'name': 'sql_db_query_checker', 'args': {'query': 'SELECT COUNT(*) FROM employees'}, 'id': '19e73b3d-639d-405d-a97a-cc81b80a27ea', 'type': 'tool_call'}], usage_metadata={'input_tokens': 717, 'output_tokens': 19, 'total_tokens': 736}),
  ToolMessage(content='SELECT COUNT(*) FROM employees', name='sql_db_query_checker', id='6663a857-23de-40e2-bbb2-f1a3a26e67c7

In [44]:
# Same question again, but printed one streamed item at a time, separated
# by a divider line, so each agent and tool step can be followed.
for s in agent_executor.stream(
    {"messages": [HumanMessage(content=question)]}
):
    print(s)
    print("----")

{'agent': {'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={'model': 'llama3.2:latest', 'created_at': '2025-02-16T16:02:28.596127Z', 'done': True, 'done_reason': 'stop', 'total_duration': 11644306458, 'load_duration': 34521000, 'prompt_eval_count': 717, 'prompt_eval_duration': 5380000000, 'eval_count': 68, 'eval_duration': 6222000000, 'message': Message(role='assistant', content='', images=None, tool_calls=None)}, id='run-6b6efed4-94ef-4e8c-9a7d-40c9f3caef26-0', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'ea153938-8ac1-4c02-a76f-5a0e77496954', 'type': 'tool_call'}, {'name': 'sql_db_schema', 'args': {'table_names': 'table1'}, 'id': '35948a33-7d51-46f4-aed4-c40a4c422b83', 'type': 'tool_call'}, {'name': 'sql_db_query_checker', 'args': {'query': 'SELECT COUNT(*) FROM table1'}, 'id': '4325018b-6de1-4b6f-bb6f-c8b5f2d07cef', 'type': 'tool_call'}, {'name': 'sql_db_query', 'args': {'query': 'SELECT COUNT(*) FROM table1'}, 'id': '8aa915a4-84ea-4a6e-9a