Wrapper class for the locally hosted LLM, which is reached over HTTP at `/api/v1/generate`

In [1]:
import requests as rq
import json

class LLM():
    """Minimal client for a locally hosted LLM served over HTTP at ``/api/v1/generate``.

    Args:
        hostname: Host of the generation server (e.g. ``"localhost"``).
        port: Port the server listens on.
        timeout: Seconds ``requests`` waits for the server before raising a
            timeout error; ``None`` waits indefinitely.
    """

    # Sampling parameters of the "Divine Intellect" preset, sent with every request.
    GENERATION_PARAMS = {
        "max_context_length": 2048,
        "max_length": 120,
        "temperature": 1.31,
        "top_k": 49,
        "top_p": 0.14,
        # Meant to end generation once the model closes the SQL block of the prompt format.
        "stop_sequence": ["</SQLQuery>\n"],
    }

    def __init__(self, hostname, port, timeout=300):
        self.url = f"http://{hostname}:{port}/api/v1/generate"
        self.timeout = timeout

    def _build_payload(self, prompt):
        """Return a fresh JSON request body: the preset parameters plus ``prompt``."""
        payload = dict(self.GENERATION_PARAMS)
        # Copy the list too so callers mutating a payload cannot alter the shared preset.
        payload["stop_sequence"] = list(payload["stop_sequence"])
        payload["prompt"] = prompt
        return payload

    def get_response(self, user_request):
        """Send ``user_request`` as the complete prompt and return the generated text.

        Returns:
            The ``text`` field of the first entry in the response's ``results`` list.

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
            requests.Timeout: if the server does not respond within ``self.timeout`` seconds.
        """
        # ``json=`` serialises the body and also sets ``Content-Type: application/json``;
        # a pre-serialised string passed positionally (as ``data``) is sent without it.
        response = rq.post(self.url, json=self._build_payload(user_request), timeout=self.timeout)
        # Report the HTTP status instead of failing later while indexing an error body.
        response.raise_for_status()
        return response.json()["results"][0]["text"]

In [2]:
# Client for the generation server running on this machine.
llm = LLM(hostname="localhost", port=5001)

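Agent class: translates natural-language questions into SQL with the LLM and runs them against the SQLite database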

In [3]:
import sqlalchemy as db
from sqlalchemy.sql import text
import random as rd

class QueriesTranslationAgent():
    """Translate natural-language questions into SQL with an LLM and run them on SQLite.

    Args:
        llm: Object exposing ``get_response(prompt) -> str`` (e.g. the ``LLM`` wrapper).
        _db: Path to the SQLite database file.
        examples_file: Text file of few-shot examples separated by blank lines;
            at least two are required.
        seed: Optional seed for choosing the few-shot examples. When ``None``
            (default) the module-level ``random`` generator is used.

    Raises:
        ValueError: if ``examples_file`` contains fewer than two examples.
    """

    def __init__(self, llm, _db, examples_file, seed=None):
        self.llm = llm
        self.examples = self._load_examples(examples_file)
        if len(self.examples) < 2:
            raise ValueError(
                f"{examples_file!r} must contain at least 2 examples separated by a blank line; "
                f"found {len(self.examples)}"
            )
        self.engine = db.create_engine(f"sqlite:///{_db}")
        # A private Random makes prompt construction reproducible when a seed is given.
        self._rng = rd.Random(seed) if seed is not None else rd

    @staticmethod
    def _load_examples(examples_file):
        """Return the non-empty, blank-line-separated chunks of ``examples_file``, stripped."""
        with open(examples_file, "r", encoding="utf-8") as file:
            content = file.read()
        # Filtering drops the empty chunk a trailing blank line would otherwise produce,
        # so it can never be sampled as an "example".
        return [chunk.strip() for chunk in content.split("\n\n") if chunk.strip()]

    def build_prompt(self, user_request):
        """Build a two-shot prompt: task template + two random examples + ``user_request``.

        The prompt ends with an opening ``<SQLQuery>`` tag so the model continues
        with the query for ``user_request``.
        """
        example1, example2 = self._rng.sample(self.examples, 2)
        # NOTE(review): the question block is closed with `<question>` rather than
        # `</question>`; presumably this mirrors the layout of examples.txt — confirm.
        template = """
You are an agent designed to translate questions about a database into SQL queries. You have a table called apple_daily with the following columns:

(Date DATE, Time TIME, Open REAL, High REAL, Low REAL, Close REAL, Volume INTEGER)

Never query for all the columns from a specific table, only ask for the relevant columns given the question.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
You should always think about what to do step by step.
Translate:

{example1}

{example2}

<question>
Question: {user_request}
<question>
<SQLQuery>
"""
        return template.format(example1=example1, example2=example2, user_request=user_request)

    def parse_raw_response(self, response):
        """Extract the SQL statement from the model's raw completion.

        Uses the first line containing ``SQLQuery:`` and takes the text after it,
        dropping a ``</SQLQuery>`` tag on the same line and surrounding whitespace.

        Raises:
            ValueError: if no non-empty ``SQLQuery:`` line is present (e.g. the
                model answered "I don't know", which the prompt allows).
        """
        marker = "SQLQuery:"
        for line in response.splitlines():
            if marker in line:
                query = line.split(marker, 1)[1].split("</SQLQuery>", 1)[0].strip()
                if query:
                    return query
                break
        raise ValueError(f"No SQL query found in model response: {response!r}")

    def invoke(self, user_request):
        """Translate ``user_request`` into SQL, execute it read-only and return all rows.

        Raises:
            ValueError: if the model's response contains no SQL query.
            sqlalchemy.exc.DBAPIError: if the generated SQL fails to execute, e.g. it
                references a missing column or tries to modify the database.
        """
        prompt = self.build_prompt(user_request)
        response = self.llm.get_response(prompt)
        query = self.parse_raw_response(response)
        # The context manager releases the connection even when the query fails.
        with self.engine.connect() as conn:
            # The SQL is model output and the prompt only *asks* for no writes, so make
            # SQLite reject any statement that would modify the database.
            conn.execute(text("PRAGMA query_only = ON"))
            return conn.execute(text(query)).fetchall()

In [4]:
# Build the agent over the SQLite database and the few-shot examples file,
# then start with a simple retrieval question.
agent = QueriesTranslationAgent(llm=llm, _db="apple.db", examples_file="examples.txt")
print(agent.invoke(user_request="Give me the first 5 values of the table apple_daily"))

[('2023-01-03', '01:00:00', 130.28, 130.9, 124.17, 125.07, 112117471), ('2023-01-04', '01:00:00', 126.89, 128.6557, 125.08, 126.36, 89113633), ('2023-01-05', '01:00:00', 127.13, 127.77, 124.76, 125.02, 80962708), ('2023-01-06', '01:00:00', 126.01, 130.29, 124.89, 129.62, 87754715), ('2023-01-09', '01:00:00', 130.465, 133.41, 129.89, 130.15, 70790813)]


In [5]:
# Exact-match filter on the Date column.
print(agent.invoke(user_request="Give me all the rows corresponding to the date 2023-01-04"))

[('2023-01-04', '01:00:00', 126.89, 128.6557, 125.08, 126.36, 89113633)]


In [6]:
# Inclusive date-range filter.
print(agent.invoke(user_request="Give me all the rows whose date goes between 2023-01-04 and 2023-01-15, including both dates too"))

[('2023-01-04', '01:00:00', 126.89, 128.6557, 125.08, 126.36, 89113633), ('2023-01-05', '01:00:00', 127.13, 127.77, 124.76, 125.02, 80962708), ('2023-01-06', '01:00:00', 126.01, 130.29, 124.89, 129.62, 87754715), ('2023-01-09', '01:00:00', 130.465, 133.41, 129.89, 130.15, 70790813), ('2023-01-10', '01:00:00', 130.26, 131.2636, 128.12, 130.73, 63896155), ('2023-01-11', '01:00:00', 131.25, 133.51, 130.46, 133.49, 69458949), ('2023-01-12', '01:00:00', 133.88, 134.26, 131.44, 133.41, 71379648), ('2023-01-13', '01:00:00', 132.03, 134.92, 131.66, 134.76, 57809719)]


In [5]:
# Plain row count of the whole table.
print(agent.invoke(user_request="Count how many rows are there in the table apple_daily"))

[(250,)]


In [8]:
# Year filter without a date-format hint. The recorded result (0) is presumably wrong:
# the YYYY-MM-DD variant of this question further down returned 250, the full row count.
print(agent.invoke(user_request="Count how many rows in the table are from the year 2023"))

[(0,)]


In [9]:
# Month filter without a date-format hint. In the recorded run the generated query used a
# non-existent `Month` column instead of SQLite date functions for a specific month, so it
# failed. The error is caught and printed so "Restart & Run All" continues past this cell;
# since LLM output varies between runs, any exception is reported, not only database errors.
try:
    print(agent.invoke("Count how many rows in the table are from month 03"))
except Exception as err:
    print(f"{type(err).__name__}: {err}")

OperationalError: (sqlite3.OperationalError) no such column: Month
[SQL: SELECT COUNT(*) FROM apple_daily WHERE Month = 3; ]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

In [10]:
# Same year question with an explicit hint about the date format; this returned 250, i.e. every row.
print(agent.invoke(user_request="Count how many rows in the table are from the year 2023, considering the date format is YYYY-MM-DD"))

[(250,)]


In [11]:
# Month question with the date-format hint. The generated query was
# `Select Count(*) From apple_daily Where Date = '2023-01-03';` -- a single-date match
# rather than a month filter -- so the result is incorrect.
print(agent.invoke(user_request="Count how many rows in the table are from the month 03, considering the date format is YYYY-MM-DD"))

[(0,)]


In [14]:
# Single-column aggregate.
print(agent.invoke(user_request="What's the average value of the column Close in the table apple_daily?"))

[(172.54896,)]


In [16]:
# Two aggregates in one question. The generated query was
# `SELECT AVG(Close) + AVG(Volume) FROM apple_daily;`, which adds the two averages
# instead of returning them as separate columns -- incorrect.
print(agent.invoke(user_request="What's the average value of the columns Close and Volume in the table apple_daily?"))

[(59233213.37296,)]


In [18]:
# Same question with each column named separately; now each average comes back as its own
# column. The agent appears to handle multiple columns only when the question separates them
# very literally.
print(agent.invoke(user_request="What's the average value of the column Close and the column Volume in the table apple_daily?"))

[(172.54896, 59233040.824)]


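Main failure modes observed with the baseline agent:
- It does not use SQLite date functions (such as `strftime`) when filtering by year or month, even though dates are stored as `YYYY-MM-DD` strings.
- It struggles when a question needs more than one output column: unless the columns are named separately, it may merge them into a single expression (e.g. `AVG(Close) + AVG(Volume)`).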