In [38]:
import boto3
from langchain_aws import ChatBedrock
from botocore.config import Config
import warnings
warnings.filterwarnings("ignore")

region = "us-west-2"
config = Config(
    region_name=region,
    signature_version = "v4",
    retries={
        "max_attempts":3,
        "mode" : "standard",
    }
)
bedrock_rt = boto3.client("bedrock-runtime", config=config)

sonnet_model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

model_kwargs = {
    "max_tokens" : 4096,
    "temperature" : 0.0,
    "stop_sequences" : ["Human"],
}

llm = ChatBedrock(
    client = bedrock_rt,
    model_id = sonnet_model_id,
    model_kwargs = model_kwargs,
)

In [39]:
from langchain_community.utilities.sql_database import SQLDatabase

db = SQLDatabase.from_uri(
    database_uri="sqlite:///Chinook.db",
)


In [40]:
import re
import sqlglot
from langchain_core.prompt_values import StringPromptValue

def format_query(message):
    if isinstance(message, StringPromptValue):
        text = message.text.strip()
    elif isinstance(message, str):
        text = message.strip()
    else:
        text = str(message).strip()
    
    text = " ".join(text.split("\n"))
    
    tag = "answer"
    pattern = "<"+tag+">(.*?)</"+tag+">"
    matches = re.findall(pattern, text)
    if len(matches) == 0:
        return None
    raw_sql_query = matches[0].strip()
    sql_query = sqlglot.transpile(raw_sql_query, write='sqlite', identify=True)[0]
    return sql_query

In [41]:
print(db.get_table_names())

['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [42]:
table_schema = """

CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
5 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
4	Let There Be Rock	1
5	Big Ones	3
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
5 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
4	Alanis Morissette
5	Alice In Chains
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
5 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalves	Embraer - Empresa Brasileira de Aeronáutica S.A.	Av. Brigadeiro Faria Lima, 2170	São José dos Campos	SP	Brazil	12227-000	+55 (12) 3923-5555	+55 (12) 3923-5566	luisg@embraer.com.br	3
2	Leonie	Köhler	None	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	+49 0711 2842222	None	leonekohler@surfeu.de	5
3	François	Tremblay	None	1498 rue Bélanger	Montréal	QC	Canada	H2G 1A7	+1 (514) 721-4711	None	ftremblay@gmail.com	3
4	Bjørn	Hansen	None	Ullevålsveien 14	Oslo	None	Norway	0171	+47 22 44 22 22	None	bjorn.hansen@yahoo.no	4
5	František	Wichterlová	JetBrains s.r.o.	Klanova 9/506	Prague	None	Czech Republic	14700	+420 2 4172 5555	+420 2 4172 5555	frantisekw@jetbrains.com	4
*/


CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER, 
	"BirthDate" DATETIME, 
	"HireDate" DATETIME, 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60), 
	PRIMARY KEY ("EmployeeId"), 
	FOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")
)

/*
5 rows from Employee table:
EmployeeId	LastName	FirstName	Title	ReportsTo	BirthDate	HireDate	Address	City	State	Country	PostalCode	Phone	Fax	Email
1	Adams	Andrew	General Manager	None	1962-02-18 00:00:00	2002-08-14 00:00:00	11120 Jasper Ave NW	Edmonton	AB	Canada	T5K 2N1	+1 (780) 428-9482	+1 (780) 428-3457	andrew@chinookcorp.com
2	Edwards	Nancy	Sales Manager	1	1958-12-08 00:00:00	2002-05-01 00:00:00	825 8 Ave SW	Calgary	AB	Canada	T2P 2T3	+1 (403) 262-3443	+1 (403) 262-3322	nancy@chinookcorp.com
3	Peacock	Jane	Sales Support Agent	2	1973-08-29 00:00:00	2002-04-01 00:00:00	1111 6 Ave SW	Calgary	AB	Canada	T2P 5M5	+1 (403) 262-3443	+1 (403) 262-6712	jane@chinookcorp.com
4	Park	Margaret	Sales Support Agent	2	1947-09-19 00:00:00	2003-05-03 00:00:00	683 10 Street SW	Calgary	AB	Canada	T2P 5G3	+1 (403) 263-4423	+1 (403) 263-4289	margaret@chinookcorp.com
5	Johnson	Steve	Sales Support Agent	2	1965-03-03 00:00:00	2003-10-17 00:00:00	7727B 41 Ave	Calgary	AB	Canada	T3B 1Y7	1 (780) 836-9987	1 (780) 836-9543	steve@chinookcorp.com
*/


CREATE TABLE "Genre" (
	"GenreId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("GenreId")
)

/*
5 rows from Genre table:
GenreId	Name
1	Rock
2	Jazz
3	Metal
4	Alternative & Punk
5	Rock And Roll
*/


CREATE TABLE "Invoice" (
	"InvoiceId" INTEGER NOT NULL, 
	"CustomerId" INTEGER NOT NULL, 
	"InvoiceDate" DATETIME NOT NULL, 
	"BillingAddress" NVARCHAR(70), 
	"BillingCity" NVARCHAR(40), 
	"BillingState" NVARCHAR(40), 
	"BillingCountry" NVARCHAR(40), 
	"BillingPostalCode" NVARCHAR(10), 
	"Total" NUMERIC(10, 2) NOT NULL, 
	PRIMARY KEY ("InvoiceId"), 
	FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)

/*
5 rows from Invoice table:
InvoiceId	CustomerId	InvoiceDate	BillingAddress	BillingCity	BillingState	BillingCountry	BillingPostalCode	Total
1	2	2021-01-01 00:00:00	Theodor-Heuss-Straße 34	Stuttgart	None	Germany	70174	1.98
2	4	2021-01-02 00:00:00	Ullevålsveien 14	Oslo	None	Norway	0171	3.96
3	8	2021-01-03 00:00:00	Grétrystraat 63	Brussels	None	Belgium	1000	5.94
4	14	2021-01-06 00:00:00	8210 111 ST NW	Edmonton	AB	Canada	T6G 2C7	8.91
5	23	2021-01-11 00:00:00	69 Salem Street	Boston	MA	USA	2113	13.86
*/


CREATE TABLE "InvoiceLine" (
	"InvoiceLineId" INTEGER NOT NULL, 
	"InvoiceId" INTEGER NOT NULL, 
	"TrackId" INTEGER NOT NULL, 
	"UnitPrice" NUMERIC(10, 2) NOT NULL, 
	"Quantity" INTEGER NOT NULL, 
	PRIMARY KEY ("InvoiceLineId"), 
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
	FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
)

/*
5 rows from InvoiceLine table:
InvoiceLineId	InvoiceId	TrackId	UnitPrice	Quantity
1	1	2	0.99	1
2	1	4	0.99	1
3	2	6	0.99	1
4	2	8	0.99	1
5	2	10	0.99	1
*/


CREATE TABLE "MediaType" (
	"MediaTypeId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("MediaTypeId")
)

/*
5 rows from MediaType table:
MediaTypeId	Name
1	MPEG audio file
2	Protected AAC audio file
3	Protected MPEG-4 video file
4	Purchased AAC audio file
5	AAC audio file
*/


CREATE TABLE "Playlist" (
	"PlaylistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("PlaylistId")
)

/*
5 rows from Playlist table:
PlaylistId	Name
1	Music
2	Movies
3	TV Shows
4	Audiobooks
5	90’s Music
*/


CREATE TABLE "PlaylistTrack" (
	"PlaylistId" INTEGER NOT NULL, 
	"TrackId" INTEGER NOT NULL, 
	PRIMARY KEY ("PlaylistId", "TrackId"), 
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
5 rows from PlaylistTrack table:
PlaylistId	TrackId
1	3402
1	3389
1	3390
1	3391
1	3392
*/


CREATE TABLE "Track" (
	"TrackId" INTEGER NOT NULL, 
	"Name" NVARCHAR(200) NOT NULL, 
	"AlbumId" INTEGER, 
	"MediaTypeId" INTEGER NOT NULL, 
	"GenreId" INTEGER, 
	"Composer" NVARCHAR(220), 
	"Milliseconds" INTEGER NOT NULL, 
	"Bytes" INTEGER, 
	"UnitPrice" NUMERIC(10, 2) NOT NULL, 
	PRIMARY KEY ("TrackId"), 
	FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"), 
	FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"), 
	FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)

/*
5 rows from Track table:
TrackId	Name	AlbumId	MediaTypeId	GenreId	Composer	Milliseconds	Bytes	UnitPrice
1	For Those About To Rock (We Salute You)	1	1	1	Angus Young, Malcolm Young, Brian Johnson	343719	11170334	0.99
2	Balls to the Wall	2	2	1	U. Dirkschneider, W. Hoffmann, H. Frank, P. Baltes, S. Kaufmann, G. Hoffmann	342562	5510424	0.99
3	Fast As a Shark	3	2	1	F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman	230619	3990994	0.99
4	Restless and Wild	3	2	1	F. Baltes, R.A. Smith-Diesel, S. Kaufman, U. Dirkscneider & W. Hoffman	252051	4331779	0.99
5	Princess of the Dawn	3	2	1	Deaffy & R.A. Smith-Diesel	375418	6290521	0.99
*/

"""

In [43]:
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

sql_prompt_template = """You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Only use the following tables:
{table_info}

Question: {question}

Use the following format:


Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Put the answer in between XML tags <answer></answer> 
"""


In [44]:
sql_prompt = PromptTemplate(
    input_variables=[
        "question",
    ],
    partial_variables={
        "table_info": db.get_table_info().strip(),
    },
    template=sql_prompt_template
)

In [47]:
text2sql_chain = sql_prompt | llm | StrOutputParser()

In [53]:
response = text2sql_chain.invoke({"question" : "Give me 5 different Artists with their track"})
print(response)

Question: Give me 5 different Artists with their tracks

SQLQuery:
SELECT "Artist"."Name" AS "Artist Name", "Track"."Name" AS "Track Name"
FROM "Artist"
JOIN "Album" ON "Artist"."ArtistId" = "Album"."ArtistId"
JOIN "Track" ON "Album"."AlbumId" = "Track"."AlbumId"
ORDER BY "Artist"."Name"
LIMIT 5;

SQLResult:
Artist Name|Track Name
AC/DC|For Those About To Rock (We Salute You)
Accept|Balls to the Wall
Accept|Fast As a Shark
Aerosmith|Restless and Wild
Aerosmith|Restless and Wild

<answer>
The query joins the Artist, Album, and Track tables to retrieve the Artist Name and Track Name for up to 5 different artists. The results show:

AC/DC with the track "For Those About To Rock (We Salute You)"
Accept with the tracks "Balls to the Wall" and "Fast As a Shark"
Aerosmith with the track "Restless and Wild" (appearing twice due to the LIMIT 5)
</answer>




In [72]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only."""
prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)
validation_chain = prompt | llm 

full_chain = text2sql_chain| validation_chain

In [57]:
query = full_chain.invoke({"question" : "Give me 5 different Artists with their track"})

In [58]:
print(query)

SELECT "Artist"."Name" AS "Artist Name", "Track"."Name" AS "Track Name"
FROM "Artist"
JOIN "Album" ON "Artist"."ArtistId" = "Album"."ArtistId"
JOIN "Track" ON "Album"."AlbumId" = "Track"."AlbumId"
ORDER BY "Artist"."Name"
LIMIT 5;




In [59]:
response = db.run(query)
print(response)

[('AC/DC', 'For Those About To Rock (We Salute You)'), ('AC/DC', 'Put The Finger On You'), ('AC/DC', "Let's Get It Up"), ('AC/DC', 'Inject The Venom'), ('AC/DC', 'Snowballed')]


In [60]:
from sqlalchemy import create_engine
db_engine = create_engine("sqlite:///Chinook.db")

In [61]:
from langchain_community.embeddings import BedrockEmbeddings

bedrock_client = boto3.client(service_name='bedrock-runtime', 
                              region_name='us-east-1')
embeddings_model = BedrockEmbeddings(model_id="amazon.titan-embed-text-v1",
                                       client=bedrock_client)

In [62]:
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
import pandas as pd

data_query = """
SELECT "Artist"."Name" AS "Artist Name", "Track"."Name" AS "Track Name"
FROM "Artist"
JOIN "Album" ON "Artist"."ArtistId" = "Album"."ArtistId"
JOIN "Track" ON "Album"."AlbumId" = "Track"."AlbumId"
ORDER BY "Artist"."Name"
LIMIT 5;
"""

data = pd.read_sql(data_query,db_engine)

documents = [Document(page_content=str(row.to_dict())) for _, row in data.iterrows()]

embeddings = embeddings_model.embed_documents([doc.page_content for doc in documents])
db_vector = FAISS.from_documents(documents,embeddings_model)

In [65]:
user_query = "Give me 5 different Artists with their track"
query_embedding = embeddings_model.embed_query(user_query)
best_result = db_vector.similarity_search_by_vector(query_embedding,k = 5)

best_result

[Document(page_content='{\'Artist Name\': \'AC/DC\', \'Track Name\': "Let\'s Get It Up"}'),
 Document(page_content="{'Artist Name': 'AC/DC', 'Track Name': 'Snowballed'}"),
 Document(page_content="{'Artist Name': 'AC/DC', 'Track Name': 'Inject The Venom'}"),
 Document(page_content="{'Artist Name': 'AC/DC', 'Track Name': 'Put The Finger On You'}"),
 Document(page_content="{'Artist Name': 'AC/DC', 'Track Name': 'For Those About To Rock (We Salute You)'}")]