In [1]:
import json
import logging
import os
import re
from typing import List
from typing import Optional

from google.cloud import bigquery
from langchain.chat_models import ChatVertexAI
from langchain.document_loaders import JSONLoader
from langchain.embeddings import VertexAIEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.prompts.chat import ChatPromptTemplate
from langchain.prompts.chat import HumanMessagePromptTemplate
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.vectorstores import FAISS
from tqdm import tqdm

In [2]:
# Notebook output goes through the 'langchain' logger at INFO level.
logger = logging.getLogger('langchain')
logger.setLevel(logging.INFO)

# Attach the console handler only once. Re-running this cell would otherwise stack
# another StreamHandler on the same logger and print every message multiple times.
_NOTEBOOK_HANDLER_NAME = 'notebook-stream'
if not any(handler.get_name() == _NOTEBOOK_HANDLER_NAME for handler in logger.handlers):
    _stream_handler = logging.StreamHandler()
    _stream_handler.set_name(_NOTEBOOK_HANDLER_NAME)
    logger.addHandler(_stream_handler)

In [3]:
# Local service-account key file, exported as GOOGLE_APPLICATION_CREDENTIALS.
SERVICE_ACCOUNT_KEY_PATH = './credentials/vai-key.json'
# setdefault: keep credentials the environment already provides instead of silently
# overwriting them; only fall back to the local key file when nothing is configured.
os.environ.setdefault('GOOGLE_APPLICATION_CREDENTIALS', SERVICE_ACCOUNT_KEY_PATH)

In [4]:
PROJECT = "arun-genai-bb"    # GCP project passed to the Vertex AI chat model
LOCATION = "us-central1"     # region passed to the Vertex AI chat model
MODEL_NAME = "chat-bison"    # chat model name passed to ChatVertexAI

In [5]:
# Chat model that writes the SQL. Temperature 0.0 asks for the most repeatable output,
# and replies are capped at 512 tokens.
llm = ChatVertexAI(
    model_name=MODEL_NAME,
    project=PROJECT,
    location=LOCATION,
    max_output_tokens=512,
    temperature=0.0,
)

In [6]:
class MyVertexAIEmbeddings(VertexAIEmbeddings, Embeddings):
    """Vertex AI text embeddings requested in small batches, with a tqdm progress bar.

    LangChain vector stores (e.g. ``FAISS.from_documents``) embed documents through
    ``embed_documents``. That method is overridden here so indexing actually goes
    through the batched ``embed_segments`` helper instead of bypassing it.
    """

    # Embedding model requested from Vertex AI.
    model_name: str = 'textembedding-gecko'
    # Default number of texts sent per ``client.get_embeddings`` call.
    max_batch_size: int = 5

    def embed_segments(self, segments: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
        """Embed ``segments`` in batches, showing a progress bar.

        Args:
            segments: Texts to embed.
            batch_size: Texts per request. ``None``/``0`` means use ``self.max_batch_size``.

        Returns:
            The ``.values`` of each embedding result, one per input text, in input order.

        Raises:
            ValueError: If the effective batch size is negative. Previously a negative
                step silently produced an empty result.
        """
        size = batch_size or self.max_batch_size
        if size < 1:
            raise ValueError(f'batch size must be a positive integer, got {size}')
        vectors: List[List[float]] = []
        for start in tqdm(range(0, len(segments), size)):
            batch = segments[start:start + size]
            vectors.extend(result.values for result in self.client.get_embeddings(batch))
        return vectors

    def embed_documents(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
        """LangChain ``Embeddings`` entry point for documents; delegates to ``embed_segments``."""
        return self.embed_segments(texts, batch_size=batch_size)

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string and return its ``.values`` vector."""
        embeddings = self.client.get_embeddings([query])
        return embeddings[0].values

In [7]:
# Embedding client pinned to the same GCP project and region as the chat model,
# rather than whatever default project the credentials resolve to.
embedding = MyVertexAIEmbeddings(project=PROJECT, location=LOCATION)

In [8]:
# Pass 1 corpus: one Document per line of tables.jsonl. With text_content=False,
# each Document's page_content holds that line's JSON object as a string.
documents = JSONLoader(
    file_path='./DATA/RAG/tables.jsonl',
    jq_schema='.',
    json_lines=True,
    text_content=False,
).load()

In [9]:

# Embed every table document and index the vectors in a FAISS vector store.
db = FAISS.from_documents(embedding=embedding, documents=documents)

In [10]:
# Table retriever: maximal-marginal-relevance search returning 5 documents.
# NOTE(review): lambda_mult=1 gives MMR no diversity weight, so this presumably ranks
# like plain similarity search. Confirm whether diversity among tables was intended.
table_search_kwargs = {'lambda_mult': 1, 'k': 5}
retriever = db.as_retriever(search_type='mmr', search_kwargs=table_search_kwargs)

In [11]:
# Free-text user request. It drives both retrieval passes and becomes USER_QUERY in the prompt.
query = "make a room reservation for 3 nights Miami"
matched_docs = retriever.get_relevant_documents(query)


In [12]:
# Pass 1 result: each retrieved page_content is a JSON string; keep its table_name, in rank order.
matched_tables = [json.loads(doc.page_content)['table_name'] for doc in matched_docs]

logger.info(f'Matched tables = {matched_tables}')
    

Matched tables = ['reservations', 'check_ins_outs', 'rooms', 'inventory', 'payments']


#### Columns: vector search pass 2 over column descriptions

In [13]:
# Pass 2 corpus: one Document per line of columns.jsonl (JSON kept as page_content).
# NOTE: this rebinds `documents` and `db`, replacing the table-level index built above.
documents = JSONLoader(
    file_path='./DATA/RAG/columns.jsonl',
    jq_schema='.',
    json_lines=True,
    text_content=False,
).load()
db = FAISS.from_documents(embedding=embedding, documents=documents)

In [14]:
# Peek at one column Document to see the JSON fields available for the prompt.
first_column_doc = documents[0]
first_column_doc

Document(page_content='{"dataset_name": "hotel_reservations", "table_name": "hotels", "column_name": "hotel_id", "description": "A unique identifier assigned to each hotel.", "usage": "This ID helps in maintaining a distinct record for each hotel and acts as a primary key. It\'s also used for referencing in other tables like Rooms.", "data_type": "INT64"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/bq-sql-agent/DATA/RAG/columns.jsonl', 'seq_num': 1})

In [15]:
# hard coding this - ideally this should be all table names matched based on vector search pass 1
# simplified for the sake of PoC and brevity
# Tables whose columns may reach the prompt. Hard-coded to keep the PoC short; ideally
# this would be the table names matched by vector search pass 1.
# NOTE(review): the pass-1 output (`matched_tables`) did not include 'hotels', yet the
# generated SQL joins `hotels` for its location filter. Also, `filter` shadows the
# builtin of the same name.
filter = ['hotels', 'reservations', 'rooms']

search_kwargs = dict(k=20)  # TODO derive the optimal value for k automatically

retriever = db.as_retriever(search_type='similarity', search_kwargs=search_kwargs)

In [16]:
# Pass 2: fetch the nearest column documents, parse their JSON, and keep only columns
# whose table is listed in `filter`. Retrieval order is preserved.
matched_docs = retriever.get_relevant_documents(query=query)
column_records = [json.loads(doc.page_content) for doc in matched_docs]
filtered_matched_docs = [record for record in column_records if record['table_name'] in filter]


In [17]:
# Serialize each filtered column record as one pipe-delimited `key=value` line.
# The newline-joined block is what the prompt receives as MATCHED_SCHEMA.
SCHEMA_FIELDS = ('dataset_name', 'table_name', 'column_name', 'data_type')

final = '\n'.join(
    '|'.join(f'{field}={doc[field]}' for field in SCHEMA_FIELDS)
    for doc in filtered_matched_docs
)

In [18]:
# Ordered prompt templates for the chat request (system first, then human).
messages = list()

In [19]:
# System prompt: frames the model as a BigQuery SQL expert.
template = "You are a SQL master expert capable of writing complex SQL query in BigQuery."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
# Rebind rather than append, so re-running this cell cannot leave a duplicate
# system message in the prompt.
messages = [system_message_prompt]

In [20]:
# User prompt template. Its placeholders are filled by chat_prompt.format_prompt():
#   {query}          -> the free-text user request (`query`)
#   {matched_schema} -> newline-separated
#                       "dataset_name=...|table_name=...|column_name=...|data_type=..." lines
# The IMPORTANT lines constrain the model to the retrieved columns and their tables.
# NOTE: the literal below (including trailing spaces) is sent verbatim to the model.
human_template = """Given the following inputs:
USER_QUERY:
--
{query}
--
MATCHED_SCHEMA: 
--
{matched_schema}
--
Please construct a SQL query using the MATCHED_SCHEMA and the USER_QUERY provided above. 
The goal is to determine the availability of hotels based on the provided info. 

IMPORTANT: Use ONLY the column names (column_name) mentioned in MATCHED_SCHEMA. DO NOT USE any other column names outside of this. 
IMPORTANT: Associate column_name mentioned in MATCHED_SCHEMA only to the table_name specified under MATCHED_SCHEMA.
NOTE: Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed. 
"""

In [21]:
human_message = HumanMessagePromptTemplate.from_template(human_template)
# Rebuild the full ordered list (system, then human) instead of appending, so
# re-running this cell cannot stack extra human messages into the prompt.
messages = [system_message_prompt, human_message]

In [22]:
# Combine the system and human templates into one chat prompt.
chat_prompt = ChatPromptTemplate.from_messages(messages=messages)

In [23]:
# Fill {query} and {matched_schema}, then expand into the message list sent to the model.
prompt_value = chat_prompt.format_prompt(matched_schema=final, query=query)
request = prompt_value.to_messages()

In [24]:
# Log the fully rendered messages (lazy %-formatting; same text as logging the list directly).
logger.info('%s', request)

[SystemMessage(content='You are a SQL master expert capable of writing complex SQL query in BigQuery.', additional_kwargs={}), HumanMessage(content="Given the following inputs:\nUSER_QUERY:\n--\nmake a room reservation for 3 nights Miami\n--\nMATCHED_SCHEMA: \n--\ndataset_name=hotel_reservations|table_name=reservations|column_name=start_date|data_type=DATE\ndataset_name=hotel_reservations|table_name=rooms|column_name=hotel_id|data_type=INT64\ndataset_name=hotel_reservations|table_name=rooms|column_name=room_id|data_type=INT64\ndataset_name=hotel_reservations|table_name=reservations|column_name=room_id|data_type=INT64\ndataset_name=hotel_reservations|table_name=rooms|column_name=availability|data_type=INT64\ndataset_name=hotel_reservations|table_name=rooms|column_name=room_type|data_type=STRING\ndataset_name=hotel_reservations|table_name=reservations|column_name=end_date|data_type=DATE\ndataset_name=hotel_reservations|table_name=reservations|column_name=reservation_id|data_type=INT64\nd

In [25]:
%%time 

response = llm(request)

CPU times: user 41.4 ms, sys: 6.64 ms, total: 48.1 ms
Wall time: 4.89 s


In [26]:
def extract_sql(text: str) -> str:
    """Return the SQL statement contained in a chat-model reply.

    If the reply contains a Markdown code fence (with any info string, e.g. ``sql``),
    the body of the first fence is returned. A missing closing fence reads to the end.
    Otherwise the whole reply is treated as SQL. Surrounding whitespace is stripped
    in both cases.
    """
    fenced = re.search(r'```[^\n]*\n(.*?)(?:```|\Z)', text, flags=re.DOTALL)
    body = fenced.group(1) if fenced else text
    return body.strip()


sql = extract_sql(response.content)
logger.info(sql)

SELECT 
  r.hotel_id, 
  r.room_id, 
  r.room_type, 
  r.price_per_night, 
  h.location
FROM hotel_reservations.rooms AS r
JOIN hotel_reservations.hotels AS h
ON r.hotel_id = h.hotel_id
WHERE h.location = 'Miami'
AND r.availability >= 3
AND NOT EXISTS (
  SELECT * 
  FROM hotel_reservations.reservations AS res 
  WHERE res.room_id = r.room_id 
  AND (
    (res.start_date BETWEEN DATE_ADD(CURRENT_DATE(), INTERVAL -3 DAY) AND DATE_ADD(CURRENT_DATE(), INTERVAL 3 DAY))
    OR (res.end_date BETWEEN DATE_ADD(CURRENT_DATE(), INTERVAL -3 DAY) AND DATE_ADD(CURRENT_DATE(), INTERVAL 3 DAY))
    OR (DATE_ADD(CURRENT_DATE(), INTERVAL -3 DAY) BETWEEN res.start_date AND res.end_date)
    OR (DATE_ADD(CURRENT_DATE(), INTERVAL 3 DAY) BETWEEN res.start_date AND res.end_date)
  )
)


In [27]:
# BigQuery client bound to the same project as the Vertex AI calls. Unqualified
# `dataset.table` names in the generated SQL resolve against this project.
bq = bigquery.Client(project=PROJECT)


In [28]:
def ensure_single_select(sql_text: str) -> None:
    """Raise ValueError unless ``sql_text`` looks like one read-only SELECT/WITH query.

    The SQL was generated by an LLM from free-text input, so it is screened before it
    is executed. The check is deliberately conservative: any ';' other than a trailing
    one (even inside a string literal), or a leading comment, causes rejection.
    """
    statement = sql_text.strip().rstrip(';').strip()
    if ';' in statement or not statement.upper().startswith(('SELECT', 'WITH')):
        raise ValueError(f'Refusing to execute non-SELECT SQL:\n{sql_text}')


ensure_single_select(sql)
query_job = bq.query(sql)
logger.info(query_job.to_dataframe())

   hotel_id  room_id room_type  price_per_night location
0         2        3    Deluxe            180.0    Miami
1         2        4     Suite            280.0    Miami


In [29]:
# Render each result row as a two-line entry; entries are separated by a blank line.
sql_result = '\n\n'.join(
    f'Room type = {row.room_type}\nPrice per night in $ = {row.price_per_night}'
    for row in query_job
).strip()
logger.info(sql_result)

Room type = Deluxe
Price per night in $ = 180.0

Room type = Suite
Price per night in $ = 280.0
