Wrapper class for the locally hosted LLM

In [1]:
import requests as rq
import json

class LLM():
    """Thin HTTP client for the locally hosted LLM.

    Prompts are POSTed to ``http://{hostname}:{port}/api/v1/generate``.
    """

    def __init__(self, hostname, port):
        """Build the generate-endpoint URL from the server's host and port."""
        self.url = f"http://{hostname}:{port}/api/v1/generate"

    def get_response(self, user_request, timeout=None):
        """Send the complete prompt to the LLM and return the generated text.

        Args:
            user_request: The full prompt string to complete.
            timeout: Optional number of seconds forwarded to ``requests``.
                The default ``None`` keeps the previous behaviour of waiting
                for as long as the generation takes.

        Returns:
            The ``text`` field of the first entry of the response's
            ``results`` list.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # The preset used for the parameters was Divine Intellect
        payload = {
            "max_context_length": 2048,
            "max_length": 120,
            "prompt": user_request,
            "temperature": 1.31,
            "top_k": 49,
            "top_p": 0.14,
            # Generation stops once the closing tag of the query has been produced.
            "stop_sequence": ["</SQLQuery>\n"],
        }
        response = rq.post(self.url, data=json.dumps(payload), timeout=timeout)
        # Fail loudly on HTTP errors instead of tripping over a missing
        # "results" key (or a non-JSON error page) further down.
        response.raise_for_status()
        return response.json()["results"][0]["text"]

In [2]:
# Client for the LLM server listening on this machine, port 5001.
llm = LLM(hostname="localhost", port=5001)

Wrapper class for the Vector Database that holds the context added through RAG

In [3]:
import chromadb
from sentence_transformers import SentenceTransformer
import regex as re

class ContextVectorDB():
    """ChromaDB collection of question/SQL examples used as few-shot context.

    Each example is indexed by the embedding of its question, so that the
    examples most similar to a new user question can be retrieved.
    """

    def __init__(self, collection_name):
        """Load the embedding model and create the collection `collection_name`."""
        # Initialization of the embedding model
        self.embedding_model = SentenceTransformer('sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
        self.chroma_client = chromadb.Client()
        # NOTE(review): create_collection presumably fails when a collection
        # with this name already exists in the client (e.g. when the cell is
        # re-run in the same kernel) - confirm and switch to a get-or-create
        # call if that is the case.
        self.collection = self.chroma_client.create_collection(name=collection_name)

    def parse_string_to_json(self, input_string):
        """Extract the question and the SQL query from one raw example.

        Args:
            input_string: Text holding a ``<question>`` block and a
                ``<SQLQuery>`` block.

        Returns:
            A dict with the keys ``"question"`` and ``"query_result"``; a
            value is ``None`` when the corresponding block is not found.
        """
        # Uses Regex to parse the raw input string
        question_pattern = r'<question>\s*Question:\s*(.*?)\s*\n</question>'
        query_pattern = r'<SQLQuery>\s*SQLQuery:\s*(.*?)\s*\n</SQLQuery>'

        question_match = re.search(question_pattern, input_string, re.DOTALL)
        query_match = re.search(query_pattern, input_string, re.DOTALL)

        question = question_match.group(1).strip() if question_match else None
        query = query_match.group(1).strip() if query_match else None

        result = {
            "question": question,
            "query_result": query
        }

        return result

    def load_examples(self, examples):
        """Parse raw example strings and index the well-formed ones.

        Blank chunks (such as the one left by a trailing newline in the
        examples file) are ignored. Chunks missing either the question or the
        query are skipped with a warning, so that ``None`` is never handed to
        the embedding model or written into a stored document.

        Args:
            examples: Iterable of raw example strings.
        """
        import warnings

        examples_parsed = []
        for example in examples:
            if not example.strip():
                continue
            parsed = self.parse_string_to_json(example)
            if parsed["question"] is None or parsed["query_result"] is None:
                warnings.warn(
                    f"Skipping example that does not match the expected format: {example!r}"
                )
                continue
            examples_parsed.append(parsed)
        self.populate_vectors(examples_parsed)

    def populate_vectors(self, dataset):
        """Add every parsed example to the collection.

        Only the question is embedded; the stored document contains both the
        question and its query, wrapped in the same ``<question>`` and
        ``<SQLQuery>`` tags that ``parse_string_to_json`` recognises.

        Args:
            dataset: Iterable of dicts with the keys ``"question"`` and
                ``"query_result"``.
        """
        # Iterates over the dataset and adds the embeddings to the ChromaDB collection
        for i, item in enumerate(dataset):
            user_input = item['question']
            query = item['query_result']
            combined_text = f"""
<question>
Question: {user_input}
</question>
<SQLQuery>
SQLQuery: {query}
</SQLQuery>"""
            embeddings = self.embedding_model.encode(user_input).tolist()
            self.collection.add(embeddings=[embeddings], documents=[combined_text], ids=[f"id_{i}"])

    def search_context(self, user_input, n_results=1):
        """Query the collection for the examples closest to `user_input`.

        Args:
            user_input: The user's question.
            n_results: Number of examples requested from the collection.

        Returns:
            The result of ``collection.query`` restricted to the documents.
        """
        # Method to search the ChromaDB collection for relevant context based on a user input
        user_embeddings = self.embedding_model.encode(user_input).tolist()
        return self.collection.query(query_embeddings=user_embeddings, n_results=n_results, include=['documents'])

Agent class

In [4]:
import sqlalchemy as db
from sqlalchemy.sql import text
import random as rd
import csv

class QueriesTranslationAgent():
    """Agent that turns natural-language questions into SQL and runs them.

    The prompt is built with few-shot examples retrieved from a
    ``ContextVectorDB``. The first line of the LLM answer is executed against
    the SQLite database, unless it is the sentinel ``SELECT export data;``,
    which writes the results gathered so far to ``data.csv``.
    """

    def __init__(self, llm, _db, examples_file):
        """Set up the database engine and index the few-shot examples.

        Args:
            llm: Object exposing ``get_response(prompt)``.
            _db: Path of the SQLite database file.
            examples_file: Text file with examples separated by blank lines.
        """
        self.llm = llm
        self.engine = db.create_engine(f"sqlite:///{_db}")
        # One entry per executed query; each entry is that query's list of rows.
        self.data_retreived = []

        with open(examples_file, "r") as file:
            examples = file.read()
        # Drop blank chunks (e.g. from a trailing newline) before indexing.
        examples = [chunk for chunk in examples.split("\n\n") if chunk.strip()]
        self.vector_db = ContextVectorDB("user_inputs")
        self.vector_db.load_examples(examples)

    def build_prompt(self, user_request):
        """Build the full prompt for `user_request`.

        Up to two examples are requested from the vector database and every
        example returned is included, so the prompt is still built when the
        collection holds fewer than two examples.

        Args:
            user_request: The user's question.

        Returns:
            The prompt string, ending right after ``SQLQuery: `` so that the
            model continues with the query itself.
        """
        # For the prompt, two shot prompting was used, along with a template that explained the task
        examples_raw = self.vector_db.search_context(user_request, n_results=2)
        examples = "\n\n".join(examples_raw["documents"][0])

        template = """
You are an agent designed to translate questions about a database into SQL queries.
You also have a tool that can export the results to a csv file. 
You have a table called apple_daily with the following columns:

(Date DATE, Time TIME, Open REAL, High REAL, Low REAL, Close REAL, Volume INTEGER)

Never query for all the columns from a specific table, only ask for the relevant columns given the question.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the users asks for either results to be exported or a csv file with results, you should call your tool making your answer be "SELECT export data;".
If you think the user wants to export data you shouldn't say anything else other than "SELECT export data;".
You should always think about what to do step by step.
Translate:

{examples}

<question>
Question: {user_request}
</question>
<SQLQuery>
SQLQuery: """
        prompt = template.format(examples=examples, user_request=user_request)
        return prompt

    def parse_raw_response(self, response):
        """Return the first non-empty line of the LLM answer, stripped.

        Stripping both ends keeps trailing whitespace from defeating the
        comparison with the export sentinel in ``invoke``.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            The candidate query, or ``""`` when the answer has no content.
        """
        for line in response.split("\n"):
            candidate = line.strip()
            if candidate:
                return candidate
        return ""

    def execute_query(self, query):
        """Run a SELECT statement and remember its rows for a later export.

        Args:
            query: SQL text produced by the LLM.

        Returns:
            The list of rows returned by the database.

        Raises:
            ValueError: If `query` does not start with ``SELECT``. The text
                comes from a language model, so the prompt's "no DML" rule is
                enforced here rather than trusted.
        """
        if not query.lstrip().lower().startswith("select"):
            raise ValueError(f"Refusing to execute non-SELECT statement: {query!r}")
        # The context manager closes the connection even when execution fails.
        with self.engine.connect() as connection:
            query_result = connection.execute(text(query)).fetchall()
        self.engine.dispose()
        # Store the rows themselves so that export_data writes one CSV line
        # per database row.
        self.data_retreived.append(query_result)
        return query_result

    def export_data(self):
        """Write all rows gathered so far to ``data.csv`` and reset the buffer.

        Returns:
            The confirmation message ``"Data exported to file data.csv"``.
        """
        # newline="" lets the csv writer control the line endings itself.
        with open("data.csv", "w", newline="") as file:
            writer = csv.writer(file, delimiter=",", quotechar='"', lineterminator="\n")
            for data in self.data_retreived:
                writer.writerows(data)
        self.data_retreived = []
        return "Data exported to file data.csv"

    def invoke(self, user_request):
        """Answer `user_request`: either export the stored data or run a query.

        Returns:
            The export confirmation message, or the rows returned by the
            generated query.
        """
        prompt = self.build_prompt(user_request)
        response = self.llm.get_response(prompt)
        query = self.parse_raw_response(response)
        if query == "SELECT export data;":
            return self.export_data()
        return self.execute_query(query)

In [5]:
# Agent wired to the local LLM, the SQLite database file and the examples file.
agent = QueriesTranslationAgent(llm=llm, _db="apple.db", examples_file="examples.txt")

In [6]:
# Demo: filter rows by an exact date.
print(agent.invoke(user_request="Give me all the rows corresponding to the date 2023-01-04"))

[('2023-01-04', '01:00:00', 126.89, 128.6557, 125.08, 126.36, 89113633)]


In [7]:
# Demo: count rows restricted to one year.
print(agent.invoke(user_request="Count how many rows in the table are from the year 2023"))

[(250,)]


In [8]:
# Demo: count rows restricted to one month.
print(agent.invoke(user_request="Count how many rows in the table are from month 03"))

[(23,)]


In [9]:
# Demo: aggregate over two columns.
print(agent.invoke(user_request="What's the average value of the columns Open and Time in the table apple_daily?"))

[(172.25632000000002, 1.0)]


In [10]:
# Demo: ask for the results gathered so far to be exported to a csv file.
print(agent.invoke(user_request="I want to export results to a csv file"))

Data exported to file data.csv


Sandbox for testing the Agent

In [11]:
# Interactive sandbox: keep asking for questions until the agent reports
# that the data has been exported.
output = "Begin!"
while True:
    user_request = input("Enter your question: ")
    output = agent.invoke(user_request)
    print(output)
    if output == "Data exported to file data.csv":
        break

[(172.25632000000002,)]
Data exported to file data.csv
