## LLM Flow using RAG

#### Imports

In [1]:
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.prompts.chat import HumanMessagePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.embeddings import VertexAIEmbeddings
from langchain.document_loaders import JSONLoader
from langchain.embeddings.base import Embeddings
from langchain.chat_models import ChatVertexAI
from langchain.vectorstores import FAISS
from google.cloud import bigquery
from typing import List
from tqdm import tqdm
import logging
import json
import os 

##### Setup logging

In [2]:
logger = logging.getLogger('langchain')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

#### Setup essentials

In [3]:
SERVICE_ACCOUNT_KEY_PATH = './../credentials/vai-key.json'
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_KEY_PATH

In [4]:
PROJECT = 'arun-genai-bb'
LOCATION = 'us-central1'
MODEL_NAME = 'codechat-bison@latest'

In [5]:
llm = ChatVertexAI(project=PROJECT, 
                   location=LOCATION, 
                   model_name=MODEL_NAME,
                   temperature=0.0, 
                   max_output_tokens=512)

In [6]:
class MyVertexAIEmbeddings(VertexAIEmbeddings, Embeddings):
    model_name = 'textembedding-gecko'
    max_batch_size = 5
    
    def embed_segments(self, segments: List) -> List:
        embeddings = []
        for i in tqdm(range(0, len(segments), self.max_batch_size)):
            batch = segments[i: i+self.max_batch_size]
            embeddings.extend(self.client.get_embeddings(batch))
        return [embedding.values for embedding in embeddings]
    
    def embed_query(self, query: str) -> List:
        embeddings = self.client.get_embeddings([query])
        return embeddings[0].values

In [7]:
embedding = MyVertexAIEmbeddings()

In [8]:
query = "Provide a list of all flight reservations from October 10th to October 15th, 2023"

#### Step 1: Embed and index tables info

In [9]:
documents = JSONLoader(file_path='./../data/rag-schema/tables.jsonl', jq_schema='.', text_content=False, json_lines=True).load()

In [10]:
db = FAISS.from_documents(documents=documents, embedding=embedding)

#### Step 2: Match indexed tables embedding to incoming query

In [11]:
retriever = db.as_retriever(search_type='mmr', search_kwargs={'k': 5, 'lambda_mult': 1})

In [12]:

matched_documents = retriever.get_relevant_documents(query=query)


In [13]:
matched_tables = []

for document in matched_documents:
    page_content = document.page_content
    page_content = json.loads(page_content)
    dataset_name = page_content['dataset_name']
    table_name = page_content['table_name']
    matched_tables.append(f'{dataset_name}.{table_name}')

logger.info(f'Matched tables = {matched_tables}')
    

Matched tables = ['flight_reservations.reservations', 'flight_reservations.transactions', 'flight_reservations.flights', 'hotel_reservations.reservations', 'hotel_reservations.inventory']


#### Step 3: Embed and index columns info
 

In [14]:
documents = JSONLoader(file_path='./../data/rag-schema/columns.jsonl', jq_schema='.', text_content=False, json_lines=True).load()
db = FAISS.from_documents(documents=documents, embedding=embedding)

In [15]:
documents[0]

Document(page_content='{"dataset_name": "hotel_reservations", "table_name": "hotels", "column_name": "hotel_id", "description": "A unique identifier assigned to each hotel.", "usage": "This ID helps in maintaining a distinct record for each hotel and acts as a primary key. It\'s also used for referencing in other tables like Rooms.", "data_type": "INT64"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 1})

#### Step 4: Match indexed columns embedding to incoming query

In [16]:
search_kwargs = {
    'k': 20
}

retriever = db.as_retriever(search_type='similarity', search_kwargs=search_kwargs)

In [17]:
matched_columns = retriever.get_relevant_documents(query=query)
logger.info(matched_columns)

[Document(page_content='{"dataset_name": "flight_reservations", "table_name": "reservations", "column_name": "reservation_datetime", "description": "Timestamp of when the reservation was made.", "usage": "Helps track reservation history and manage bookings.", "data_type": "DATETIME"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 62}), Document(page_content='{"dataset_name": "flight_reservations", "table_name": "flights", "column_name": "departure_datetime", "description": "The departure time of the flight.", "usage": "Informs users and helps them plan their travel.", "data_type": "DATETIME"}', metadata={'source': '/Users/arunpshankar/Desktop/Projects/LLM-Text-to-SQL-Architectures/data/rag-schema/columns.jsonl', 'seq_num': 55}), Document(page_content='{"dataset_name": "flight_reservations", "table_name": "flights", "column_name": "origin", "description": "The departure location of the flight.", "usage

In [18]:
matched_columns_filtered = []

# LangChain filters does not support multiple values at the moment
for i, column in enumerate(matched_columns):
    page_content = json.loads(column.page_content)
    dataset_name = page_content['dataset_name']
    if dataset_name == 'flight_reservations':
        matched_columns_filtered.append(page_content)

logger.info(matched_columns_filtered)


[{'dataset_name': 'flight_reservations', 'table_name': 'reservations', 'column_name': 'reservation_datetime', 'description': 'Timestamp of when the reservation was made.', 'usage': 'Helps track reservation history and manage bookings.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'departure_datetime', 'description': 'The departure time of the flight.', 'usage': 'Informs users and helps them plan their travel.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'origin', 'description': 'The departure location of the flight.', 'usage': 'Helps users find flights based on their travel plans.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'destination', 'description': 'The arrival location of the flight.', 'usage': 'Used to find flights and plan journeys.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservation

In [19]:
matched_columns_cleaned = []

for doc in matched_columns_filtered:
    dataset_name = doc['dataset_name']
    table_name = doc['table_name']
    column_name = doc['column_name']
    data_type = doc['data_type']
    matched_columns_cleaned.append(f'dataset_name={dataset_name}|table_name={table_name}|column_name={column_name}|data_type={data_type}')
    
matched_columns_cleaned = '\n'.join(matched_columns_cleaned)
logger.info(matched_columns_cleaned)

dataset_name=flight_reservations|table_name=reservations|column_name=reservation_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=departure_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=origin|data_type=STRING
dataset_name=flight_reservations|table_name=flights|column_name=destination|data_type=STRING
dataset_name=flight_reservations|table_name=transactions|column_name=transaction_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=arrival_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=customers|column_name=date_of_birth|data_type=DATE
dataset_name=flight_reservations|table_name=customers|column_name=created_at|data_type=DATETIME
dataset_name=flight_reservations|table_name=reservations|column_name=status|data_type=STRING
dataset_name=flight_reservations|table_name=transactions|column_name=reservation_id|data_type=INT64
dataset_name=fl

#### Step 5: Text-to-SQL generation

In [20]:
messages = []

In [21]:
template = "You are a SQL master expert capable of writing complex SQL query in BigQuery."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
messages.append(system_message_prompt)

In [22]:
human_template = """Given the following inputs:
USER_QUERY:
--
{query}
--
MATCHED_SCHEMA: 
--
{matched_schema}
--
Please construct a SQL query using the MATCHED_SCHEMA and the USER_QUERY provided above. 
The goal is to determine the availability of hotels based on the provided info. 

IMPORTANT: Use ONLY the column names (column_name) mentioned in MATCHED_SCHEMA. DO NOT USE any other column names outside of this. 
IMPORTANT: Associate column_name mentioned in MATCHED_SCHEMA only to the table_name specified under MATCHED_SCHEMA.
NOTE: Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed. 
"""

In [23]:
human_message = HumanMessagePromptTemplate.from_template(human_template)
messages.append(human_message)

In [24]:
chat_prompt = ChatPromptTemplate.from_messages(messages)

In [25]:
request = chat_prompt.format_prompt(query=query,
                                    matched_schema=matched_columns_cleaned).to_messages()

In [26]:
%%time 

response = llm(request)
sql = '\n'.join(response.content.strip().split('\n')[1:-1])
logger.info(sql)

SELECT 
  r.reservation_id,
  r.flight_id,
  r.customer_id,
  r.status,
  r.reservation_datetime,
  f.departure_datetime,
  f.origin,
  f.destination,
  f.arrival_datetime,
  f.carrier
FROM flight_reservations.reservations AS r
JOIN flight_reservations.flights AS f
ON r.flight_id = f.flight_id
WHERE r.reservation_datetime BETWEEN '2023-10-10' AND '2023-10-15';


CPU times: user 40.2 ms, sys: 4.46 ms, total: 44.6 ms
Wall time: 2.76 s


### Step 6: Execute the generated SQL query in BigQuery

In [27]:
bq = bigquery.Client()


In [28]:
df = bq.query(sql).to_dataframe()
df

Unnamed: 0,reservation_id,flight_id,customer_id,status,reservation_datetime,departure_datetime,origin,destination,arrival_datetime,carrier
0,6,6,6,Confirmed,2023-10-10 10:00:00,2023-11-25 06:00:00,SEA,JFK,2023-11-25 14:30:00,United
1,7,7,6,Confirmed,2023-10-12 11:30:00,2023-11-27 20:00:00,JFK,MIA,2023-11-27 23:30:00,American
