In [5]:
import warnings

import boto3
from botocore.config import Config
from langchain_aws import ChatBedrock
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda
warnings.filterwarnings("ignore")

# Bedrock runtime client: SigV4 signing, up to 3 attempts in "standard" retry mode.
region = "us-west-2"
config = Config(
    region_name=region,
    signature_version="v4",
    retries=dict(max_attempts=3, mode="standard"),
)
bedrock_rt = boto3.client("bedrock-runtime", config=config)

# Chat model: Claude 3 Sonnet, temperature 0.0, generation stops at the text "Human".
sonnet_model_id = "anthropic.claude-3-sonnet-20240229-v1:0"
model_kwargs = dict(
    max_tokens=4096,
    temperature=0.0,
    stop_sequences=["Human"],
)
llm = ChatBedrock(client=bedrock_rt, model_id=sonnet_model_id, model_kwargs=model_kwargs)

In [2]:
from langchain_community.utilities.sql_database import SQLDatabase

# LangChain wrapper around the local SQLite Chinook database file.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

In [41]:
# Print the CREATE TABLE statements and sample rows that `db` reports for its tables.
print(db.get_table_info())


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl

In [14]:
import re, sqlglot

def format_query(message, dialect="sqlite"):
    """Extract the SQL wrapped in ``<sql></sql>`` tags from a chat-model reply.

    Args:
        message: Object exposing the model's reply as a string in ``.content``.
        dialect: sqlglot dialect name used to render the query. Defaults to
            ``"sqlite"`` because this notebook executes the result against the
            SQLite ``Chinook.db`` database; rendering in another dialect can
            emit syntax SQLite does not accept.

    Returns:
        The first tagged query, re-rendered by sqlglot with every identifier
        quoted, or ``None`` when the reply has no ``<sql>`` block or the block
        is empty (for example when the model answers "NO QUERY TO GENERATE").

    Raises:
        Whatever ``sqlglot.transpile`` raises when the extracted text cannot
        be parsed as SQL.
    """
    # Flatten the reply first: `.` in the pattern does not match newlines, so a
    # query spread over several lines would otherwise be missed.
    text = message.content.strip().replace("\n", " ")

    match = re.search(r"<sql>(.*?)</sql>", text)
    if match is None:
        return None

    raw_sql_query = match.group(1).strip()
    if not raw_sql_query:
        # An empty <sql></sql> block carries no query to run.
        return None

    # identify=True quotes every identifier, which keeps mixed-case table and
    # column names such as "Artist" intact.
    return sqlglot.transpile(raw_sql_query, write=dialect, identify=True)[0]

In [6]:
# `get_table_names` is deprecated in LangChain in favour of
# `get_usable_table_names`, which returns the same list of table names.
print(db.get_usable_table_names())

['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [44]:
# Schema text injected into the SQL-generation prompt. Read it from the database
# instead of keeping a pasted copy, so it cannot drift from the real tables.
# `get_table_info()` yields the CREATE TABLE statements plus sample rows shown
# in the earlier `print(db.get_table_info())` cell.
table_schema = db.get_table_info()

In [69]:
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Prompt for the text-to-SQL step. `{table_schema}`, `{top_k}` and `{user_query}`
# are placeholders filled in when the template is formatted; `{top_k}` is the
# default row limit the model is asked to apply.
sql_prompt_template = """You are an assistant with expertise in writing syntactically correct SQL queries for a given database, based on a human's question.
Unless the user asks for a specific number of results, limit the query to at most {top_k} results.

You have access to the following table schemas between the XML tags <table></table> and based on that form the SQL query:
<table>
{table_schema}
</table>

Make sure that you only choose tables from the <table></table> XML tags. Do not create any new tables. 
Remember to respond with "NO QUERY TO GENERATE" when the question does not correspond to any table description given. In the SQL Query, use only necessary columns specified in the schema.
Remember to always place the complete sql query in one line within <sql></sql> XML tags.

Question: {user_query}

SQLQuery: """


In [70]:
# Pre-bind the schema text and `top_k` so callers only have to supply `user_query`.
sql_prompt = PromptTemplate(
    template=sql_prompt_template,
    input_variables=["user_query"],
    partial_variables=dict(table_schema=table_schema, top_k=5),
)

In [71]:
# Prompt -> chat model -> <sql> tag extraction -> plain string.
text2sql_chain = (
    sql_prompt
    | llm
    | format_query
    | StrOutputParser()
)

In [72]:
# Run the text-to-SQL chain on a sample question and show the generated query.
response = text2sql_chain.invoke({"user_query" : "Give me the count of Artists"})
print(response)

SELECT COUNT(*) FROM "Artist"


In [73]:
# Execute the generated query against the database; the cell displays the returned value.
db.run(response)

'[(275,)]'

In [74]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# System instructions for the SQL review step; `{dialect}` is bound further down.
system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only."""

# Bind the dialect reported by `db`, leaving `{query}` as the only open placeholder.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{query}"),
    ]
).partial(dialect=db.dialect)

# Review step: prompt -> chat model -> plain string.
validation_chain = prompt | llm | StrOutputParser()

# Generate the SQL first, then pass it through the review step.
full_chain = text2sql_chain | validation_chain

In [76]:
# Generate and review the SQL for a sample question using the combined chain.
query = full_chain.invoke({"user_query" : "Give me frequency of artists against their albums"})

In [77]:
# Show the SQL text returned by `full_chain`.
print(query)

SELECT "a"."Name" AS "Artist", COUNT("al"."AlbumId") AS "AlbumCount" 
FROM "Artist" AS "a"
JOIN "Album" AS "al" ON "a"."ArtistId" = "al"."ArtistId"
GROUP BY "a"."Name"




In [78]:
# Execute the reviewed SQL and print the result; note that `response` is rebound
# here from the earlier generated-query string to the query result.
response = db.run(query)
print(response)

[('AC/DC', 2), ('Aaron Copland & London Symphony Orchestra', 1), ('Aaron Goldberg', 1), ('Academy of St. Martin in the Fields & Sir Neville Marriner', 1), ('Academy of St. Martin in the Fields Chamber Ensemble & Sir Neville Marriner', 1), ('Academy of St. Martin in the Fields, John Birch, Sir Neville Marriner & Sylvia McNair', 1), ('Academy of St. Martin in the Fields, Sir Neville Marriner & Thurston Dart', 1), ('Accept', 2), ('Adrian Leaper & Doreen de Feis', 1), ('Aerosmith', 1), ('Aisha Duo', 1), ('Alanis Morissette', 1), ('Alberto Turco & Nova Schola Gregoriana', 1), ('Alice In Chains', 1), ('Amy Winehouse', 2), ('Anne-Sophie Mutter, Herbert Von Karajan & Wiener Philharmoniker', 1), ('Antal Doráti & London Symphony Orchestra', 1), ('Antônio Carlos Jobim', 2), ('Apocalyptica', 1), ('Aquaman', 1), ('Audioslave', 3), ('BackBeat', 1), ('Barry Wordsworth & BBC Concert Orchestra', 1), ('Battlestar Galactica', 2), ('Battlestar Galactica (Classic)', 1), ('Berliner Philharmoniker & Hans Ros

In [79]:
from sqlalchemy import create_engine
# Plain SQLAlchemy engine on the same SQLite file that `db` wraps.
# NOTE(review): no cell visible in this notebook references `db_engine` — confirm where it is used.
db_engine = create_engine("sqlite:///Chinook.db")

In [68]:
from langchain_community.embeddings import BedrockEmbeddings

# Separate Bedrock runtime client for embeddings; note it targets us-east-1,
# unlike the chat client configured earlier with us-west-2.
bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")
embeddings_model = BedrockEmbeddings(
    client=bedrock_client,
    model_id="amazon.titan-embed-text-v1",
)

***Creating a RAG Layer to query for responses faster***


In [None]:
# Prompt for the answer-generation step. `{user_query}` and `{context}` are
# filled in from the dict returned by `generate_context`.
rag_template = """
You are an assistant with access to a database. You have received the following question from a user:
Question: {user_query}

Context: {context}

Based on the information from the query results, provide a detailed and accurate answer to the user's question.
Answer:
"""


def preprocess_data(user_query, sql_query, k=5):
    """Return the query-result rows most similar to the user's question.

    Args:
        user_query: Natural-language question; embedded and used as the search vector.
        sql_query: SQL text handed to `get_data`.
        k: Maximum number of rows to return (default 5).

    Returns:
        A list of up to `k` `Document` objects, one per selected row, or an
        empty list when the query produced no rows.
    """
    # NOTE(review): `get_data` is not defined in the visible part of this notebook.
    # It is presumably a helper returning a pandas DataFrame for `sql_query`
    # (the code below relies on `.iterrows()`) — confirm it exists before running.
    data = get_data(sql_query)
    documents = [Document(page_content=str(row.to_dict())) for _, row in data.iterrows()]
    if not documents:
        # Nothing to index; skip building a vector store from an empty list.
        return []

    # `FAISS.from_documents` expects the embeddings model itself and embeds the
    # documents internally, so no separate `embed_documents` call is needed.
    db_vector = FAISS.from_documents(documents, embeddings_model)
    query_embedding = embeddings_model.embed_query(user_query)
    return db_vector.similarity_search_by_vector(query_embedding, k=k)


def generate_context(inputs):
    """Build the prompt variables for `rag_template`.

    Args:
        inputs: Mapping with the keys "user_query" and "sql_query".

    Returns:
        A dict with "user_query" (passed through unchanged) and "context"
        (the selected rows' text joined with newlines).
    """
    user_query = inputs["user_query"]
    best_result = preprocess_data(user_query, inputs["sql_query"])
    context = "\n".join(doc.page_content for doc in best_result)
    return {"user_query": user_query, "context": context}


def invoke_full_chain(user_query):
    """Run `full_chain` for `user_query` and return its output (the reviewed SQL)."""
    return full_chain.invoke({"user_query": user_query})


# Question -> SQL -> row retrieval -> answer prompt -> chat model -> plain string.
rag_chain = (
    RunnableLambda(lambda x: {"user_query": x["user_query"], "sql_query": invoke_full_chain(x["user_query"])})
    | RunnableLambda(generate_context)
    | RunnableLambda(lambda x: rag_template.format(**x))
    | llm
    | StrOutputParser()
)

# `rag_chain` already runs `full_chain` in its first step and expects the original
# {"user_query": ...} dict. Piping `full_chain` in front of it would hand that step
# a bare SQL string, so the complete pipeline is `rag_chain` itself.
complete_chain = rag_chain

# Invoke the chain
result = complete_chain.invoke({"user_query": "Give all the artist"})
print(result)
