## Setup Connection to database and model

In [114]:

from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain_community.llms import HuggingFaceEndpoint
import sqlite3
import os
import requests
from urllib.parse import urlparse

os.environ['HUGGINGFACEHUB_API_TOKEN'] = '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'

# db = SQLDatabase.from_uri("vanna.ai/Chinook.sqlite")
url = 'https://vanna.ai/Chinook.sqlite'

path = os.path.basename(urlparse(url).path)

# Download the database if it doesn't exist
if not os.path.exists(url):
    response = requests.get(url)
    response.raise_for_status()  # Check that the request was successful
    with open(path, "wb") as f:
        f.write(response.content)
    url = path

conn = sqlite3.connect(url, check_same_thread=False)
cursor = conn.cursor()

repo_id = "mistralai/Mistral-7B-v0.1"

llm = HuggingFaceEndpoint(
    repo_id=repo_id, max_length=128, temperature=0.5, token=os.environ['HUGGINGFACEHUB_API_TOKEN']
)




                    max_length was transferred to model_kwargs.
                    Please make sure that max_length is what you intended.
                    token was transferred to model_kwargs.
                    Please make sure that token is what you intended.


Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /Users/garvitmathur/.cache/huggingface/token
Login successful


## Extract Database Metadata Chunks

In [115]:


def table_info_chunks(c):
    '''
    Returns table level chunks of a structured string containing information about all tables in the database.
    
    Parameters:
    c : cursor object
    
    Returns:
    chunks : List
        List of a table level Structured string containing table name, column names, their types, primary key, and constraints.
    '''

    # Query the schema information from sqlite_master table
    c.execute("SELECT name, sql FROM sqlite_master WHERE type='table'")
    tables = c.fetchall()
    chunks = []
    # Iterate through each table
    for table in tables:
        table_name = table[0]
        sql_statement = table[1]
        schema_string = ""
        # Check if the table has SQL statement
        if sql_statement:
            schema_string += f"Table: {table_name}\n"
            columns_info = sql_statement[sql_statement.index('(') + 1:sql_statement.rindex(')')].split(',')

            # Iterate through each column
            for column_info in columns_info:
                # Extract column name and data type
                column_info_parts = column_info.strip().split()
                column_name = column_info_parts[0]
                column_type = ' '.join(column_info_parts[1:])

                # Append column information to schema string
                schema_string += f"\tColumn: {column_name}, Type: {column_type}\n"

            # Check for primary key
            if 'PRIMARY KEY' in sql_statement:
                primary_key = sql_statement[sql_statement.index('PRIMARY KEY'):]
                schema_string += f"\tPrimary Key: {primary_key}\n"

            # Check for constraints
            constraints = [part.strip() for part in sql_statement.split('\n') if 'CONSTRAINT' in part]
            for constraint in constraints:
                schema_string += f"\tConstraint: {constraint}\n"

            schema_string += "\n"
            chunks.append(schema_string)

    return chunks

In [116]:
table_chunks = table_info_chunks(cursor)
print(table_chunks[0])


Table: Album
	Column: [AlbumId], Type: INTEGER NOT NULL
	Column: [Title], Type: NVARCHAR(160) NOT NULL
	Column: [ArtistId], Type: INTEGER NOT NULL
	Column: CONSTRAINT, Type: [PK_Album] PRIMARY KEY ([AlbumId])
	Column: FOREIGN, Type: KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) ON DELETE NO ACTION ON UPDATE NO ACTION
	Primary Key: PRIMARY KEY  ([AlbumId]),
    FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) 
		ON DELETE NO ACTION ON UPDATE NO ACTION
)
	Constraint: CONSTRAINT [PK_Album] PRIMARY KEY  ([AlbumId]),




## Extract SQL Query Chunks

In [117]:
sql_query = '''    -- 1. Provide a query showing Customers (just their full names, customer ID and country) who are not in the US.
select customerid, firstname, lastname, country
from customer
where not country = 'USA';

-- 2. Provide a query only showing the Customers from Brazil.
select * from customer
where country = 'Brazil';

-- 3. Provide a query showing the Invoices of customers who are from Brazil. The resultant table should show the customer's full name, Invoice ID, Date of the invoice and billing country.
select c.firstname, c.lastname, i.invoiceid, i.invoicedate, i.billingcountry
from customer as c, invoice as i
where c.country = 'Brazil' and
c.customerid = i.customerid;

-- 4. Provide a query showing only the Employees who are Sales Agents.
select * from employee
where employee.title = 'Sales Support Agent';

-- 5. Provide a query showing a unique list of billing countries from the Invoice table.
select distinct billingcountry from invoice;

-- 6. Provide a query showing the invoices of customers who are from Brazil.
select *
from customer as c, invoice as i
where c.country = 'Brazil' and
c.customerid = i.customerid;

-- 7. Provide a query that shows the invoices associated with each sales agent. The resultant table should include the Sales Agent's full name.
select e.firstname, e.lastname, i.invoiceid, i.customerid, i.invoicedate, i.billingaddress, i.billingcountry, i.billingpostalcode, i.total
from customer as c, invoice as i
on c.customerid = i.customerid
join employee as e
on e.employeeid = c.supportrepid
order by e.employeeid;

-- 8. Provide a query that shows the Invoice Total, Customer name, Country and Sale Agent name for all invoices and customers.
select e.firstname as 'employee first', e.lastname as 'employee last', c.firstname as 'customer first', c.lastname as 'customer last', c.country, i.total
from employee as e
	join customer as c on e.employeeid = c.supportrepid
	join invoice as i on c.customerid = i.customerid

-- 9. How many Invoices were there in 2009 and 2011? What are the respective total sales for each of those years?
select count(i.invoiceid), sum(i.total)
from invoice as i
where i.invoicedate between datetime('2011-01-01 00:00:00') and datetime('2011-12-31 00:00:00');

select count(i.invoiceid), sum(i.total)
from invoice as i
where i.invoicedate between datetime('2009-01-01 00:00:00') and datetime('2009-12-31 00:00:00');

-- 10. Looking at the InvoiceLine table, provide a query that COUNTs the number of line items for Invoice ID 37.
select count(i.invoicelineid)
from invoiceline as i
where i.invoiceid = 37

-- 11. Looking at the InvoiceLine table, provide a query that COUNTs the number of line items for each Invoice. HINT: [GROUP BY](http://www.sqlite.org/lang_select.html#resultset)
select invoiceid, count(invoicelineid)
from invoiceline
group by invoiceid

-- 12. Provide a query that includes the track name with each invoice line item.
select i.*, t.name
from invoiceline as i, track as t
on i.trackid = t.trackid

-- 13. Provide a query that includes the purchased track name AND artist name with each invoice line item.
select i.*, t.name as 'track', ar.name as 'artist'
from invoiceline as i
	join track as t on i.trackid = t.trackid
	join album as al on al.albumid = t.albumid
	join artist as ar on ar.artistid = al.artistid

-- 14. Provide a query that shows the # of invoices per country. HINT: [GROUP BY](http://www.sqlite.org/lang_select.html#resultset)
select billingcountry, count(billingcountry) as '# of invoices'
from invoice
group by billingcountry

-- 15. Provide a query that shows the total number of tracks in each playlist. The Playlist name should be include on the resultant table.
select *, count(trackid) as '# of tracks'
from playlisttrack, playlist
on playlisttrack.playlistid = playlist.playlistid
group by playlist.playlistid

-- 16. Provide a query that shows all the Tracks, but displays no IDs. The resultant table should include the Album name, Media type and Genre.
select t.name as 'track', t.composer, t.milliseconds, t.bytes, t.unitprice, a.title as 'album', g.name as 'genre', m.name as 'media type'
from track as t
	join album as a on a.albumid = t.albumid
	join genre as g on g.genreid = t.genreid
	join mediatype as m on m.mediatypeid = t.mediatypeid

-- 17. Provide a query that shows all Invoices but includes the # of invoice line items.
select invoice.*, count(invoiceline.invoicelineid) as '# of line items'
from invoice, invoiceline
on invoice.invoiceid = invoiceline.invoiceid
group by invoice.invoiceid

-- 18. Provide a query that shows total sales made by each sales agent.
select e.*, count(i.invoiceid) as 'Total Number of Sales'
from employee as e
	join customer as c on e.employeeid = c.supportrepid
	join invoice as i on i.customerid = c.customerid
group by e.employeeid

-- 19. Which sales agent made the most in sales in 2009?
select *, max(total) from
(select e.*, sum(total) as 'Total'
from employee as e
	join customer as c on e.employeeid = c.supportrepid
	join invoice as i on i.customerid = c.customerid
where i.invoicedate between '2009-01-00' and '2009-12-31'
group by e.employeeid)


-- 20. Which sales agent made the most in sales in 2010?
select *, max(total) from
(select e.*, sum(total) as 'Total'
from employee as e
	join customer as c on e.employeeid = c.supportrepid
	join invoice as i on i.customerid = c.customerid
where i.invoicedate between '2010-01-00' and '2010-12-31'
group by e.employeeid)

-- 21. Which sales agent made the most in sales over all?
select *, max(total) from
(select e.*, sum(total) as 'Total'
from employee as e
	join customer as c on e.employeeid = c.supportrepid
	join invoice as i on i.customerid = c.customerid
group by e.employeeid)

-- 22. Provide a query that shows the # of customers assigned to each sales agent.
select e.*, count(c.customerid) as 'TotalCustomers'
from employee as e
	join customer as c on e.employeeid = c.supportrepid
group by e.employeeid

-- 23. Provide a query that shows the total sales per country. Which country's customers spent the most?
select i.billingcountry, sum(total) as 'TotalSales'
from invoice as i
group by billingcountry
order by totalsales desc

-- 24. Provide a query that shows the most purchased track of 2013.
select *, count(t.trackid) as count
from invoiceline as il
	join invoice as i on i.invoiceid = il.invoiceid
	join track as t on t.trackid = il.trackid
where i.invoicedate between '2013-01-01' and '2013-12-31'
group by t.trackid
order by count desc'''

In [118]:
from langchain.text_splitter import CharacterTextSplitter
text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = text_splitter.split_text(sql_query)

In [119]:
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate

In [144]:
print(chunks[0])

-- 1. Provide a query showing Customers (just their full names, customer ID and country) who are not in the US.
select customerid, firstname, lastname, country
from customer
where not country = 'USA';

-- 2. Provide a query only showing the Customers from Brazil.
select * from customer
where country = 'Brazil';


In [132]:
vectorstore = FAISS.from_texts(table_chunks, embedding=HuggingFaceEmbeddings())
vectorstore_query = FAISS.from_texts(chunks, embedding=HuggingFaceEmbeddings())

In [133]:
res = vectorstore.similarity_search("List all the tracks with there corresponding genres")
res2 = vectorstore_query.similarity_search("Provide a query that shows the Invoice Total, Customer name, Country and Sale Agent name for all invoices and customers.")

In [136]:
print(res[0].page_content)

Table: Genre
	Column: [GenreId], Type: INTEGER NOT NULL
	Column: [Name], Type: NVARCHAR(120)
	Column: CONSTRAINT, Type: [PK_Genre] PRIMARY KEY ([GenreId])
	Primary Key: PRIMARY KEY  ([GenreId])
)
	Constraint: CONSTRAINT [PK_Genre] PRIMARY KEY  ([GenreId])




In [77]:
print(res[0].page_content)

Table: Genre
	Column: [GenreId], Type: INTEGER NOT NULL
	Column: [Name], Type: NVARCHAR(120)
	Column: CONSTRAINT, Type: [PK_Genre] PRIMARY KEY ([GenreId])
	Primary Key: PRIMARY KEY  ([GenreId])
)
	Constraint: CONSTRAINT [PK_Genre] PRIMARY KEY  ([GenreId])




# Prompt without Structured Output

In [78]:
prompt1 = PromptTemplate(
    template='''
    You are a english to sql query writer. Use the sql database information provided below to write sql queries to answer the question.\n
    The database contains the following tables and columns name:
    database information:\n
    {context}\n\n
    
    You can use the following example SQL statements as a reference of how to write a valid query:\n\n"
    {queries}\n\n

    
    Use the following format to return the query and question:

    Question: Question here
    SQLQuery: SQL Query here
    \n\n------------------------\n\n
    question:\n\n"
    {question}\n\n
''',
input_variables=['context', 'queries', 'question'],
)

In [79]:
prompt2 = PromptTemplate(
    template='''
    You are a sqlite sql query writer. Your task is to write the most relevant sqlite query for the given question using the database context and sample queries provided. \n\n
    But before writing the query perform the following steps to ensure the query return is relevant:\n\n
    1. Understande the database schema and reformat the user question to identify the key entities, attributes and relationships mentioned.\n\n
    2. Generate a draft SQL query using the identified entities, attributes and relationships while joining tables where needed, and keep in mind to include the minimal information asked in question in the sql query to not provide any extra information to the user\n\n
    3. Run the draft query on a sample database if available and verify the results are relevant.\n\n
    4. If needed refine the query based on the results and feedback until you get a query that returns relevant data.\n\n
    5. Return the final query along with the original question\n\n
    6. Validate the returned results are relevant and make further refinements if needed\n
    7. In case you are not able to find any relevant query return the mssage "No relevant query found for the question" and suggest an alternative question close to original question
    8. Just provide the single most relevant query
    
    Use the sql database information provided below to write sql queries to answer the question.\n
    The database contains the following tables and columns name, database information:\n
    {context}\n\n
    
    You can use the following example SQL statements as a reference of how to write a valid query:\n\n"
    {queries}\n\n
    
    \n\n------------------------\n\n
    question:\n\n"
    {question}\n\n

    \n\n------------------------\n\n
     
    Format the output as JSON with the following keys:
    Question
    SQLQuery
''',
input_variables=['context', 'queries', 'question'],
)

In [80]:
question="list most recent invoice based on InvoiceDate and provide detail of invoiced items"

context = vectorstore.similarity_search(question)
context_str =""
for i in range(len(context)):
    context_str+=context[i].page_content+"\n"


queries = vectorstore_query.similarity_search(question)
queries_str =""
for i in range(len(queries)):
    queries_str+=queries[i].page_content+"\n"

prompt_formatted_str: str = prompt2.format(
    context = context_str,
    queries = queries_str,
    question=question)

In [81]:
print(prompt_formatted_str)


    You are a sqlite sql query writer. Your task is to write the most relevant sqlite query for the given question using the database context and sample queries provided. 


    But before writing the query perform the following steps to ensure the query return is relevant:


    1. Understande the database schema and reformat the user question to identify the key entities, attributes and relationships mentioned.


    2. Generate a draft SQL query using the identified entities, attributes and relationships while joining tables where needed, and keep in mind to include the minimal information asked in question in the sql query to not provide any extra information to the user


    3. Run the draft query on a sample database if available and verify the results are relevant.


    4. If needed refine the query based on the results and feedback until you get a query that returns relevant data.


    5. Return the final query along with the original question


    6. Validate the returned

In [82]:
prediction = llm.invoke(prompt_formatted_str)

In [83]:
print(prediction)

    Result


    Example:

    {
    "Question": "list most recent invoice based on InvoiceDate and provide detail of invoiced items",
    "SQLQuery": "select * from invoice order by invoicedate desc",
    "Result": [
    {
    "InvoiceId": 1,
    "CustomerId": 1,
    "InvoiceDate": "2013-01-01",
    "BillingAddress": "123 Main St",
    "BillingCity": "New York",
    "BillingState": "NY",
    "BillingCountry": "USA",
    "BillingPostalCode": "10001",
    "Total": 1000
    },
    {
    "InvoiceId": 2,
    "CustomerId": 2,
    "InvoiceDate": "2013-02-01",
    "BillingAddress": "456 Elm St",
    "BillingCity": "Los Angeles",
    "BillingState": "CA",
    "BillingCountry": "USA",
    "BillingPostalCode": "90001",
    "Total": 2000
    }
    ]
    }






















































































































































































































# Prompt with Structured Output Parsing

In [121]:
from langchain.output_parsers import ResponseSchema, StructuredOutputParser

question_des_schema = ResponseSchema(name='Question', description='This is the question asked by the user')
sql_query_schema = ResponseSchema(name='SQLQuery', description='This is the generated sqlite sql query')
response_schemas = [question_des_schema, sql_query_schema]

In [122]:
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)

In [123]:
format_response = output_parser.get_format_instructions()

In [124]:
prompt3 = PromptTemplate(
    template='''
    You are a sqlite sql query writer. Your task is to write the most relevant sqlite query for the given question using the database context and sample queries provided. \n\n
    But before writing the query perform the following steps to ensure the query return is relevant:\n\n
    1. Understand the database schema and reformat the question to identify the key entities, attributes and relationships mentioned.\n\n
    2. Generate a draft SQL query using the identified entities, attributes and relationships while joining tables where needed, and keep in mind to include the minimal information asked in question in the sql query to not provide any extra information to the user\n\n
    3. If the user hasn't mentioned specificaly in the question then make sure to limit the response to 5 values only\n\n
    4. Run the draft query on a sample database if available and verify the results are relevant.\n\n
    5. If needed refine the query based on the results and feedback until you get a query that returns relevant data.\n\n
    6. Return the final query along with the original question\n\n
    7. Validate the returned results are relevant and make further refinements if needed\n
    8. In case you are not able to find any relevant query return empty string as a query and suggest an alternative question close to original question
    9. Just provide the single most relevant query
    
    Use the sql database information provided below to write sql queries to answer the question.\n
    The database contains the following tables and columns name, database information:\n
    {context}\n\n
    
    You can use the following example SQL statements as a reference of how to write a valid query:\n\n
    {queries}\n\n
    
    \n\n------------------------\n\n
    question:\n\n
    {question}\n\n
    
    \n\n------------------------\n\n
     
    {format_instructions}
''',
input_variables=['context', 'queries', 'question', 'format_instructions'],
)

In [125]:
prompt4 = PromptTemplate(
    template='''
    You are a sqlite sql query writer. You generates the sqlite query that is most relevant to the given question based on the database schema and context provided. The generated query should include only the minimal information asked in the question without providing any extra details.\n\n
    While writing the sqlite query, keep in mind:
    1. To provide only the information asked in the question without any extra details
    2. Use proper syntax, indentation and formatting for readability
    3. Reference the tables and columns names using the database schema information provided
    4. Use proper syntax like SELECT, FROM, WHERE etc. to write the query
    5. Check that the query results match the question before returning it

    Use the sql database information provided below to write sql queries to answer the question.
    The database contains the following tables and columns name, database information:
    {context}
    
    You can use the following example SQL statements as a reference of how to write a valid query:
    {queries}
    
    ------------------------
    question:
    {question}

    ------------------------
     
    {format_instructions}
''',
input_variables=['context', 'queries', 'question', 'format_instructions'],
)

In [126]:
def generate_prompt(text):
    context = vectorstore.similarity_search(text)
    context_str =""
    for i in range(len(context)):
        context_str+=context[i].page_content+"\n"


    queries = vectorstore_query.similarity_search(text)
    queries_str =""
    for i in range(len(queries)):
        queries_str+=queries[i].page_content+"\n"

    prompt_formatted_str: str = prompt4.format(
        context = context_str,
        queries = queries_str,
        question=question,
        format_instructions = format_response)
    print(prompt_formatted_str)
    return prompt_formatted_str

# Parse Output as Dictionary

In [127]:
def parse_prediction(prediction):
    try:
        return output_parser.parse(prediction)
    except:
        print("Error parsing prediction")

# Execute SQL Query

In [128]:
def execute_sql_query(query):
    try:
        return cursor.execute(query).fetchall()
    except Exception:
        print("Error executing SQL query"+str(Exception))

In [129]:
def generate_sql_from_text(text):
    prompt = generate_prompt(text)
    prediction = llm.invoke(prompt)
    print(f'LLM Prediction: {prediction}')
    formatted_prediction = parse_prediction(prediction)
    print('SQL Query: '+ formatted_prediction['SQLQuery'])
    return formatted_prediction['SQLQuery']
    
    

In [146]:
# question="list all the tables name in the database"
# question="list all the tables in the database"
question="what is the price of each albums"
# question = "list most recent invoice based on InvoiceDate and provide detail of invoiced items"
sql = generate_sql_from_text(question)


    You are a sqlite sql query writer. You generates the sqlite query that is most relevant to the given question based on the database schema and context provided. The generated query should include only the minimal information asked in the question without providing any extra details.


    While writing the sqlite query, keep in mind:
    1. To provide only the information asked in the question without any extra details
    2. Use proper syntax, indentation and formatting for readability
    3. Reference the tables and columns names using the database schema information provided
    4. Use proper syntax like SELECT, FROM, WHERE etc. to write the query
    5. Check that the query results match the question before returning it

    Use the sql database information provided below to write sql queries to answer the question.
    The database contains the following tables and columns name, database information:
    Table: Album
	Column: [AlbumId], Type: INTEGER NOT NULL
	Column: [Title]

In [131]:
print(execute_sql_query(sql))

