<a href="https://colab.research.google.com/github/GCPQuantum/LLM-Text-to-SQL-Architectures/blob/main/02-Pattern-II/00-llm-flow-with-rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## LLM Flow using RAG

#### Imports

In [2]:
!pip install langchain
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.prompts.chat import HumanMessagePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.embeddings import VertexAIEmbeddings
from langchain.document_loaders import JSONLoader
from langchain.embeddings.base import Embeddings
from langchain.chat_models import ChatVertexAI
from langchain.vectorstores import FAISS
from google.cloud import bigquery
from typing import List
from tqdm import tqdm
import logging
import json
import os

Collecting langchain
  Downloading langchain-0.1.11-py3-none-any.whl (807 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m807.5/807.5 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain)
  Downloading dataclasses_json-0.6.4-py3-none-any.whl (28 kB)
Collecting jsonpatch<2.0,>=1.33 (from langchain)
  Downloading jsonpatch-1.33-py2.py3-none-any.whl (12 kB)
Collecting langchain-community<0.1,>=0.0.25 (from langchain)
  Downloading langchain_community-0.0.27-py3-none-any.whl (1.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain-core<0.2,>=0.1.29 (from langchain)
  Downloading langchain_core-0.1.30-py3-none-any.whl (256 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m256.9/256.9 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain-text-splitters<0.1,>=0.0.1 (from langchain)
  Downloa

##### Setup logging

In [3]:
logger = logging.getLogger('langchain')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

#### Setup essentials

In [6]:
SERVICE_ACCOUNT_KEY_PATH = '/content/gcpquantummain-c2330e59b0c4.json'
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_KEY_PATH

In [12]:
PROJECT = 'gcpquantummain'
LOCATION = 'us-central1'
MODEL_NAME = 'codechat-bison'

In [15]:
llm = ChatVertexAI(project=PROJECT,
                   location=LOCATION,
                   model_name=MODEL_NAME,
                   temperature=0.0,
                   max_output_tokens=512)

llm

ChatVertexAI(project='gcpquantummain', model_name='codechat-bison', client=<vertexai.language_models.CodeChatModel object at 0x78c491bae7d0>, max_output_tokens=512)

In [16]:
class MyVertexAIEmbeddings(VertexAIEmbeddings, Embeddings):
    model_name = 'textembedding-gecko'
    max_batch_size = 5

    def embed_segments(self, segments: List) -> List:
        embeddings = []
        for i in tqdm(range(0, len(segments), self.max_batch_size)):
            batch = segments[i: i+self.max_batch_size]
            embeddings.extend(self.client.get_embeddings(batch))
        return [embedding.values for embedding in embeddings]

    def embed_query(self, query: str) -> List:
        embeddings = self.client.get_embeddings([query])
        return embeddings[0].values

In [29]:
embedding = MyVertexAIEmbeddings()
embedding



MyVertexAIEmbeddings(project=None, location='us-central1', request_parallelism=5, max_retries=6, stop=None, model_name='textembedding-gecko@001', client=<vertexai.language_models.TextEmbeddingModel object at 0x78c47efcb7c0>, client_preview=None, temperature=0.0, max_output_tokens=128, top_p=0.95, top_k=40, credentials=None, n=1, streaming=False, instance={'max_batch_size': 250, 'batch_size': 250, 'min_batch_size': 5, 'min_good_batch_size': 5, 'lock': <unlocked _thread.lock object at 0x78c49022b900>, 'batch_size_validated': False, 'task_executor': <concurrent.futures.thread.ThreadPoolExecutor object at 0x78c57c308910>, 'embeddings_task_type_supported': False}, show_progress_bar=False, max_batch_size=5)

In [18]:
query = "Provide a list of all flight reservations from October 10th to October 15th, 2023"

#### Step 1: Embed and index tables info

In [26]:
!pip install jq

documents = JSONLoader(file_path='/content/tables.jsonl', jq_schema='.', text_content=False, json_lines=True).load()





In [28]:
!pip install faiss-cpu
db = FAISS.from_documents(documents=documents, embedding=embedding)

Collecting faiss-cpu
  Downloading faiss_cpu-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.0/27.0 MB[0m [31m46.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: faiss-cpu
Successfully installed faiss-cpu-1.8.0


#### Step 2: Match indexed tables embedding to incoming query

In [30]:
retriever = db.as_retriever(search_type='mmr', search_kwargs={'k': 5, 'lambda_mult': 1})

In [34]:

matched_documents = retriever.get_relevant_documents(query=query)
matched_documents


[Document(page_content='{"dataset_name": "flight_reservations", "table_name": "reservations", "description": "The `reservations` table keeps track of all flight bookings, linking them to customers and flights and recording when the reservation was made and its current status.", "example_queries": ["Find all reservations made by a specific customer.", "Show all reservations for a particular flight.", "Retrieve reservations made within a specific time period.", "Which reservations have a particular status (e.g., confirmed, cancelled, etc.)?", "List the reservations in chronological order."]}', metadata={'source': '/content/tables.jsonl', 'seq_num': 33}),
 Document(page_content='{"dataset_name": "flight_reservations", "table_name": "transactions", "description": "The `transactions` table documents all payment transactions related to flight reservations, including the reservation they pertain to, the amount paid, and when the transaction took place.", "example_queries": ["Find all transact

In [32]:
matched_tables = []

for document in matched_documents:
    page_content = document.page_content
    page_content = json.loads(page_content)
    dataset_name = page_content['dataset_name']
    table_name = page_content['table_name']
    matched_tables.append(f'{dataset_name}.{table_name}')

logger.info(f'Matched tables = {matched_tables}')


Matched tables = ['flight_reservations.reservations', 'flight_reservations.transactions', 'flight_reservations.flights', 'hotel_reservations.reservations', 'hotel_reservations.inventory']
INFO:langchain:Matched tables = ['flight_reservations.reservations', 'flight_reservations.transactions', 'flight_reservations.flights', 'hotel_reservations.reservations', 'hotel_reservations.inventory']


#### Step 3: Embed and index columns info


In [36]:
documents = JSONLoader(file_path='/content/columns.jsonl', jq_schema='.', text_content=False, json_lines=True).load()
db = FAISS.from_documents(documents=documents, embedding=embedding)

In [37]:
documents[0]

Document(page_content='{"dataset_name": "hotel_reservations", "table_name": "hotels", "column_name": "hotel_id", "description": "A unique identifier assigned to each hotel.", "usage": "This ID helps in maintaining a distinct record for each hotel and acts as a primary key. It\'s also used for referencing in other tables like Rooms.", "data_type": "INT64"}', metadata={'source': '/content/columns.jsonl', 'seq_num': 1})

#### Step 4: Match indexed columns embedding to incoming query

In [38]:
search_kwargs = {
    'k': 20
}

retriever = db.as_retriever(search_type='similarity', search_kwargs=search_kwargs)

In [39]:
matched_columns = retriever.get_relevant_documents(query=query)
logger.info(matched_columns)

[Document(page_content='{"dataset_name": "flight_reservations", "table_name": "reservations", "column_name": "reservation_datetime", "description": "Timestamp of when the reservation was made.", "usage": "Helps track reservation history and manage bookings.", "data_type": "DATETIME"}', metadata={'source': '/content/columns.jsonl', 'seq_num': 62}), Document(page_content='{"dataset_name": "flight_reservations", "table_name": "flights", "column_name": "departure_datetime", "description": "The departure time of the flight.", "usage": "Informs users and helps them plan their travel.", "data_type": "DATETIME"}', metadata={'source': '/content/columns.jsonl', 'seq_num': 55}), Document(page_content='{"dataset_name": "flight_reservations", "table_name": "flights", "column_name": "origin", "description": "The departure location of the flight.", "usage": "Helps users find flights based on their travel plans.", "data_type": "STRING"}', metadata={'source': '/content/columns.jsonl', 'seq_num': 53}), 

In [51]:
matched_columns_filtered = []

# LangChain filters does not support multiple values at the moment
for i, column in enumerate(matched_columns):
    page_content = json.loads(column.page_content)
    dataset_name = page_content['dataset_name']
    if dataset_name == 'flight_reservations':
        matched_columns_filtered.append(page_content)

logger.info(matched_columns_filtered)


[{'dataset_name': 'flight_reservations', 'table_name': 'reservations', 'column_name': 'reservation_datetime', 'description': 'Timestamp of when the reservation was made.', 'usage': 'Helps track reservation history and manage bookings.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'departure_datetime', 'description': 'The departure time of the flight.', 'usage': 'Informs users and helps them plan their travel.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'origin', 'description': 'The departure location of the flight.', 'usage': 'Helps users find flights based on their travel plans.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'destination', 'description': 'The arrival location of the flight.', 'usage': 'Used to find flights and plan journeys.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservation

In [86]:
matched_columns_cleaned = []

for doc in matched_columns_filtered:
    dataset_name = doc['dataset_name']
    table_name = doc['table_name']
    column_name = doc['column_name']
    data_type = doc['data_type']
    matched_columns_cleaned.append(f'dataset_name={dataset_name}|table_name={table_name}|column_name={column_name}|data_type={data_type}')

matched_columns_cleaned = '\n'.join(matched_columns_cleaned)
logger.info(matched_columns_cleaned)

dataset_name=flight_reservations|table_name=reservations|column_name=reservation_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=departure_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=origin|data_type=STRING
dataset_name=flight_reservations|table_name=flights|column_name=destination|data_type=STRING
dataset_name=flight_reservations|table_name=transactions|column_name=transaction_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=arrival_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=customers|column_name=created_at|data_type=DATETIME
dataset_name=flight_reservations|table_name=customers|column_name=date_of_birth|data_type=DATE
dataset_name=flight_reservations|table_name=reservations|column_name=status|data_type=STRING
dataset_name=flight_reservations|table_name=transactions|column_name=reservation_id|data_type=INT64
dataset_name=fl

#### Step 5: Text-to-SQL generation

In [88]:
messages = []

In [89]:
template = "You are a SQL master expert capable of writing complex SQL query in BigQuery."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
messages.append(system_message_prompt)


In [91]:
human_template = """Given the following inputs:
USER_QUERY:
--
{query}
--
MATCHED_SCHEMA:
--
{matched_schema}
--
Please construct a SQL query using the MATCHED_SCHEMA and the USER_QUERY provided above.
The goal is to determine the availability of hotels based on the provided info.

IMPORTANT: Use ONLY the column names (column_name) mentioned in MATCHED_SCHEMA. DO NOT USE any other column names outside of this.
IMPORTANT: Associate column_name mentioned in MATCHED_SCHEMA only to the table_name specified under MATCHED_SCHEMA.
NOTE: when referencing table name use full name alongwith dataset name. Dataset name and table name should be seperated by dot
NOTE: Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed.
"""

In [92]:
human_message = HumanMessagePromptTemplate.from_template(human_template)
messages.append(human_message)

In [93]:
chat_prompt = ChatPromptTemplate.from_messages(messages)

In [94]:
request = chat_prompt.format_prompt(query=query,
                                    matched_schema=matched_columns_cleaned).to_messages()

In [95]:
%%time

response = llm(request)
sql = '\n'.join(response.content.strip().split('\n')[1:-1])
logger.info(sql)

SELECT 
    r.reservation_id,
    r.reservation_datetime,
    r.status,
    f.departure_datetime,
    f.origin,
    f.destination,
    f.carrier
FROM 
    flight_reservations.reservations AS r
JOIN 
    flight_reservations.flights AS f
ON 
    r.flight_id = f.flight_id
WHERE 
    r.reservation_datetime BETWEEN '2023-10-10' AND '2023-10-15';
INFO:langchain:SELECT 
    r.reservation_id,
    r.reservation_datetime,
    r.status,
    f.departure_datetime,
    f.origin,
    f.destination,
    f.carrier
FROM 
    flight_reservations.reservations AS r
JOIN 
    flight_reservations.flights AS f
ON 
    r.flight_id = f.flight_id
WHERE 
    r.reservation_datetime BETWEEN '2023-10-10' AND '2023-10-15';


CPU times: user 16.8 ms, sys: 985 µs, total: 17.8 ms
Wall time: 1.32 s


### Step 6: Execute the generated SQL query in BigQuery

In [96]:
bq = bigquery.Client()


In [97]:
df = bq.query(sql).to_dataframe()
df

Unnamed: 0,reservation_id,reservation_datetime,status,departure_datetime,origin,destination,carrier
0,6,2023-10-10 10:00:00,Confirmed,2023-11-25 06:00:00,SEA,JFK,United
1,6,2023-10-10 10:00:00,Confirmed,2023-11-25 06:00:00,SEA,JFK,United
2,7,2023-10-12 11:30:00,Confirmed,2023-11-27 20:00:00,JFK,MIA,American
3,7,2023-10-12 11:30:00,Confirmed,2023-11-27 20:00:00,JFK,MIA,American
