- This task involves converting text input into a structured representation and using it to generate a semantically correct SQL query that can be executed on a database.
- Model used will be PaLM 2 model from Google. 
- Our focal point will be a flight reservation system, a complex domain with intricate relationships and data structures. Within this system, we’ll explore various tables encompassing flight details, passenger information, booking history, and more, all hosted within BigQuery datasets. Flight systems encompass a vast array of interconnected tables and datasets, ranging from customer details and ticketing information to flight schedules and pricing metrics. Will consider 4 tables - reservations, customers, transactions, and flights. 

- Architectural Patterns: 
    - We outline five distinct patterns for implementing LLMs in SQL query generation.

#### Create Bigquery Tables

In [11]:
from google.cloud.exceptions import NotFound
from google.cloud import bigquery
import pandas as pd
import logging
import os

In [12]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

In [14]:
SERVICE_ACCOUNT_CREDENTIALS = 'gcp_bigquery.json'
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_CREDENTIALS

client = bigquery.Client()

- Create dataset

In [15]:
dataset_id = f"{client.project}.flight_reservations"
logger.info(dataset_id)

fluent-buckeye-422006.flight_reservations
fluent-buckeye-422006.flight_reservations


In [16]:
try:
    dataset = client.get_dataset(dataset_id)
    logger.info(f"Dataset {dataset_id} already exists!")
except NotFound:
    dataset = bigquery.Dataset(dataset_id)
    dataset.location = "US"
    dataset = client.create_dataset(dataset)
    logger.info(f"Dataset {dataset_id} created.")

Dataset fluent-buckeye-422006.flight_reservations created.
Dataset fluent-buckeye-422006.flight_reservations created.


- Create tables
- Creating customers table - 

In [17]:
table_id = f"{dataset_id}.customers"
logger.info(table_id)

customers_schema = [
    bigquery.SchemaField("customer_id", "INT64", mode="REQUIRED"),
    bigquery.SchemaField("first_name", "STRING", mode="REQUIRED"),
    bigquery.SchemaField("last_name", "STRING", mode="REQUIRED"),
    bigquery.SchemaField("email", "STRING", mode="REQUIRED"),
    bigquery.SchemaField("date_of_birth", "DATE", mode="REQUIRED"),
    bigquery.SchemaField("created_at", "DATETIME", mode="REQUIRED"),
]

fluent-buckeye-422006.flight_reservations.customers
fluent-buckeye-422006.flight_reservations.customers


In [18]:
try:
    customers_table = client.get_table(table_id)
    logger.info(f"Table {table_id} already exists!")
except NotFound:
    customers_table = bigquery.Table(table_id, schema=customers_schema)
    customers_table = client.create_table(customers_table)  
    logger.info(f"Table {table_id} created.")

Table fluent-buckeye-422006.flight_reservations.customers created.
Table fluent-buckeye-422006.flight_reservations.customers created.


- Create flights table

In [19]:
table_id = f"{dataset_id}.flights"
logger.info(table_id)

flights_schema = [
    bigquery.SchemaField("flight_id", "INT64", mode="REQUIRED"),
    bigquery.SchemaField("origin", "STRING", mode="REQUIRED"),
    bigquery.SchemaField("destination", "STRING", mode="REQUIRED"),
    bigquery.SchemaField("departure_datetime", "DATETIME", mode="REQUIRED"),
    bigquery.SchemaField("arrival_datetime", "DATETIME", mode="REQUIRED"),
    bigquery.SchemaField("carrier", "STRING", mode="REQUIRED"),
    bigquery.SchemaField("price", "FLOAT64", mode="REQUIRED"),
]

fluent-buckeye-422006.flight_reservations.flights
fluent-buckeye-422006.flight_reservations.flights


In [20]:
try:
    flights_table = client.get_table(table_id)
    logger.info(f"Table {table_id} already exists!")
except NotFound:
    flights_table = bigquery.Table(table_id, schema=flights_schema)
    flights_table = client.create_table(flights_table)  
    logger.info(f"Table {table_id} created.")

Table fluent-buckeye-422006.flight_reservations.flights created.
Table fluent-buckeye-422006.flight_reservations.flights created.


- Create reservations table

In [22]:
table_id = f"{dataset_id}.reservations"
logger.info(table_id)

reservations_schema = [
    bigquery.SchemaField("reservation_id", "INT64", mode="REQUIRED"),
    bigquery.SchemaField("customer_id", "INT64", mode="REQUIRED"),
    bigquery.SchemaField("flight_id", "INT64", mode="REQUIRED"),
    bigquery.SchemaField("reservation_datetime", "DATETIME", mode="REQUIRED"),
    bigquery.SchemaField("status", "STRING", mode="REQUIRED"),
]

fluent-buckeye-422006.flight_reservations.reservations
fluent-buckeye-422006.flight_reservations.reservations


In [23]:
try:
    reservations_table = client.get_table(table_id)
    logger.info(f"Table {table_id} already exists!")
except NotFound:
    reservations_table = bigquery.Table(table_id, schema=reservations_schema)
    reservations_table = client.create_table(reservations_table)  
    logger.info(f"Table {table_id} created.")

Table fluent-buckeye-422006.flight_reservations.reservations created.
Table fluent-buckeye-422006.flight_reservations.reservations created.


- Create transactions table

In [24]:
table_id = f"{dataset_id}.transactions"
logger.info(table_id)

transactions_schema = [
    bigquery.SchemaField("transaction_id", "INT64", mode="REQUIRED"),
    bigquery.SchemaField("reservation_id", "INT64", mode="REQUIRED"),
    bigquery.SchemaField("amount", "FLOAT64", mode="REQUIRED"),
    bigquery.SchemaField("transaction_datetime", "DATETIME", mode="REQUIRED"),
]

fluent-buckeye-422006.flight_reservations.transactions
fluent-buckeye-422006.flight_reservations.transactions


In [25]:
try:
    transactions_table = client.get_table(table_id)
    logger.info(f"Table {table_id} already exists!")
except NotFound:
    transactions_table = bigquery.Table(table_id, schema=transactions_schema)
    transactions_table = client.create_table(transactions_table)  
    logger.info(f"Table {table_id} created.")

Table fluent-buckeye-422006.flight_reservations.transactions created.
Table fluent-buckeye-422006.flight_reservations.transactions created.


- Populate tables

In [33]:
job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.CSV, skip_leading_rows=1, autodetect=True,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE  #added to have truncate and insert load
)

with open('text2sql_data/flights.csv', "rb") as source_file:
    job = client.load_table_from_file(source_file, customers_table, job_config=job_config)
    
job.result()  # Waits for the job to complete.

table = client.get_table(customers_table)  # Make an API request.
print(
    "Loaded {} rows and {} columns to {}".format(
        table.num_rows, len(table.schema), customers_table
    )
)

Loaded 20 rows and 6 columns to fluent-buckeye-422006.flight_reservations.customers


In [35]:
job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.CSV, skip_leading_rows=1, autodetect=True,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE  #added to have truncate and insert load
)

with open('text2sql_data/flights.csv', "rb") as source_file:
    job = client.load_table_from_file(source_file, flights_table, job_config=job_config)
    
job.result()  # Waits for the job to complete.

table = client.get_table(flights_table)  # Make an API request.
print(
    "Loaded {} rows and {} columns to {}".format(
        table.num_rows, len(table.schema), flights_table
    )
)

Loaded 20 rows and 7 columns to fluent-buckeye-422006.flight_reservations.flights


In [36]:
job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.CSV, skip_leading_rows=1, autodetect=True,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE  #added to have truncate and insert load
)

with open('text2sql_data/reservations.csv', "rb") as source_file:
    job = client.load_table_from_file(source_file, reservations_table, job_config=job_config)
    
job.result()  # Waits for the job to complete.

table = client.get_table(reservations_table)  # Make an API request.
print(
    "Loaded {} rows and {} columns to {}".format(
        table.num_rows, len(table.schema), reservations_table
    )
)

Loaded 20 rows and 5 columns to fluent-buckeye-422006.flight_reservations.reservations


In [37]:
job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.CSV, skip_leading_rows=1, autodetect=True,
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE  #added to have truncate and insert load
)

with open('text2sql_data/transactions.csv', "rb") as source_file:
    job = client.load_table_from_file(source_file, transactions_table, job_config=job_config)
    
job.result()  # Waits for the job to complete.

table = client.get_table(transactions_table)  # Make an API request.
print(
    "Loaded {} rows and {} columns to {}".format(
        table.num_rows, len(table.schema), transactions_table
    )
)

Loaded 16 rows and 4 columns to fluent-buckeye-422006.flight_reservations.transactions


#### Pattern 1: Intent Detection and Entity Recognition with Text-to-SQL

- The process of converting text to SQL queries begins with intent detection. This step is crucial as it discerns the user’s purpose from their query. Intent detection is typically approached as a multi-class classification problem, requiring a supervised learning model that is trained on a dataset balanced across all possible intents. The introduction to LLM has changed the landscape. These models can perform tasks like intent detection in a zero-shot or few-shot manner, eliminating the need for extensive training data.
- Another key component of traditional text-to-SQL systems is Named Entity Recognition (NER), which involves identifying and extracting entities from the user’s input.

1. Intent Detection:
    - The user’s query or utterance is input into the LLM, which determines the user’s intent. For example, the intent behind a query in the context of booking a flight reservation.
    - Example : 
        - Need all the bookings from 10th to 15th October 2023.,RETRIEVE_RESERVATIONS
        - Who made a reservation last Wednesday?,IDENTIFY_RECENT_CUSTOMERS

In [3]:
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.prompts.chat import HumanMessagePromptTemplate
from langchain.prompts.chat import AIMessagePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.chat_models import ChatVertexAI
from google.cloud import bigquery
import pandas as pd
import logging
import os

In [14]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())


In [60]:
SERVICE_ACCOUNT_CREDENTIALS = 'gcp_bigquery.json'
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_CREDENTIALS

In [103]:
PROJECT = 'fluent-buckeye-422006'
LOCATION = 'us-central1'
MODEL_NAME = 'chat-bison-001'

bq = bigquery.Client()

In [63]:
from langchain.chat_models import ChatOpenAI
os.environ["OPEN_API_KEY"] = ""
llm = ChatOpenAI(temperature=0, openai_api_key=os.environ["OPEN_API_KEY"])

  warn_deprecated(


In [64]:
user_query = "Provide a list of all flight reservations from October 10th to October 15th, 2023"

In [65]:
messages = []
examples = pd.read_csv('text2sql_data/few-shot/prompts_intent.csv')
examples.head()

Unnamed: 0,prompt,intent
0,Need all the bookings from 10th to 15th Octobe...,RETRIEVE_RESERVATIONS
1,Could you retrieve reservations for mid-Octobe...,RETRIEVE_RESERVATIONS
2,Let’s see all the October reservations from 10...,RETRIEVE_RESERVATIONS
3,Any reservations from 10/10/2023 to 15/10/2023?,RETRIEVE_RESERVATIONS
4,I’m looking for bookings between the second an...,RETRIEVE_RESERVATIONS


In [66]:
template = "You are a helpful assistant capable of detecting the intent behind a user's query."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
messages.append(system_message_prompt)

In [67]:
for _, row in examples.iterrows():
    prompt, completion = row
    human_message = HumanMessagePromptTemplate.from_template(prompt)
    messages.append(human_message)
    ai_message = AIMessagePromptTemplate.from_template(completion)
    messages.append(ai_message)

human_template = "{user_query}"
human_message = HumanMessagePromptTemplate.from_template(human_template)
messages.append(human_message)

In [69]:
chat_prompt = ChatPromptTemplate.from_messages(messages)
request = chat_prompt.format_prompt(user_query=user_query).to_messages()

In [70]:
%%time 

response = llm(request)
intent = response.content.strip()
logger.info(intent)

  warn_deprecated(
RETRIEVE_RESERVATIONS
RETRIEVE_RESERVATIONS


CPU times: user 73.3 ms, sys: 55.3 ms, total: 129 ms
Wall time: 4.13 s


2. Entity Recognition : 
    - Extract the entities from the user query. To perform NER — the extraction of relevant entities from the text. 
    - Example: 
        - "Can you show me all the reservations from October 10th to October 15th, 2023?","Start Date:October 10th, 2023|End Date:October 15th, 2023"
        - "What bookings do we have from 10/10/2023 to 10/15/2023?","Start Date:10/10/2023|End Date:10/15/2023"

In [71]:
messages = []
examples = pd.read_csv('text2sql_data/few-shot/prompts_ner.csv')
examples.head()

Unnamed: 0,prompt,entities
0,Can you show me all the reservations from Octo...,"Start Date:October 10th, 2023|End Date:October..."
1,What bookings do we have from 10/10/2023 to 10...,Start Date:10/10/2023|End Date:10/15/2023
2,Show the reservations occurring between the se...,"Start Date:October 8th, 2023|End Date:October ..."
3,List all bookings that are happening from Octo...,"Start Date:October 10, 2023|End Date:October 1..."
4,Fetch the reservations from the second week of...,"Start Date:October 8th, 2023|End Date:October ..."


In [73]:
template = "You are a helpful assistant capable of performing named entity recognition."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
messages.append(system_message_prompt)

for _, row in examples.iterrows():
    prompt, completion = row
    human_message = HumanMessagePromptTemplate.from_template(prompt)
    messages.append(human_message)
    ai_message = AIMessagePromptTemplate.from_template(completion)
    messages.append(ai_message)

human_template = "{user_query} Standardize the date format to YYYY-MM-DD."
human_message = HumanMessagePromptTemplate.from_template(human_template)
messages.append(human_message)

In [74]:
chat_prompt = ChatPromptTemplate.from_messages(messages)
request = chat_prompt.format_prompt(user_query=user_query).to_messages()

In [75]:
%%time 

response = llm(request)
entities = response.content.strip()
logger.info(entities)

Start Date:2023-10-10|End Date:2023-10-15
Start Date:2023-10-10|End Date:2023-10-15


CPU times: user 61.1 ms, sys: 4.98 ms, total: 66.1 ms
Wall time: 1.44 s


3. Mapping Intent to Database Tables: 
    - The detected intent is used to determine which database tables to query. While LLMs can assist in this step, it can also be programmatically treated as key-value mapping retrieval
    - Examples: 
        - RETRIEVE_RESERVATIONS,reservations|flights
        - IDENTIFY_RECENT_CUSTOMERS,reservations|customers
        - CALCULATE_REVENUE,reservations|transactions
        - FIND_PEAK_DEPARTURE_MONTHS,flights

In [76]:
messages = []
examples = pd.read_csv('text2sql_data/few-shot/intent_to_table_mapping.csv')
examples.head()

Unnamed: 0,intent,mapped_tables
0,RETRIEVE_RESERVATIONS,reservations|flights
1,IDENTIFY_RECENT_CUSTOMERS,reservations|customers
2,CALCULATE_REVENUE,reservations|transactions
3,FIND_PEAK_DEPARTURE_MONTHS,flights
4,GROUP_AND_COUNT_CUSTOMERS_BY_AGE,customers


In [77]:
template = "You are a helpful assistant capable of mapping detected intent to the correct list of BigQuery tables."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
messages.append(system_message_prompt)

for _, row in examples.iterrows():
    prompt, completion = row
    human_message = HumanMessagePromptTemplate.from_template(prompt)
    messages.append(human_message)
    ai_message = AIMessagePromptTemplate.from_template(completion)
    messages.append(ai_message)

human_template = "{user_intent}"
human_message = HumanMessagePromptTemplate.from_template(human_template)
messages.append(human_message)

In [78]:
chat_prompt = ChatPromptTemplate.from_messages(messages)
request = chat_prompt.format_prompt(user_intent=intent).to_messages()

In [79]:
%%time 

response = llm(request)
tables = response.content.strip()
logger.info(tables)

reservations|flights
reservations|flights


CPU times: user 16 ms, sys: 2.9 ms, total: 18.9 ms
Wall time: 1.03 s


4. Load and filter table schemas: 

In [81]:
def read_files_from_dir(directory):
    if not os.path.exists(directory):
        logger.warn(f"The directory {directory} does not exist!")
        return {}

    # Create an empty dictionary to store filename and content
    files_dict = {}

    # Iterate over each file in the directory
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)

        # Ensure it's a file and not a sub-directory or other entity
        if os.path.isfile(file_path):
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()
                filename = filename.split('.txt')[0]
                files_dict[filename] = content

    return files_dict

In [88]:
directory_path = 'text2sql_data/text-schema/'
table_schemas = read_files_from_dir(directory_path)

In [89]:
table_names = tables.split('|')
filtered_table_schemas = {}

for table_name in table_names:
    if table_name in table_schemas.keys():
        filtered_table_schemas[table_name] = table_schemas[table_name]

filtered_table_schemas_text = []
for schema in filtered_table_schemas.values():
    filtered_table_schemas_text.append(schema)

filtered_table_schemas_text = ''.join(filtered_table_schemas_text)
logger.info(filtered_table_schemas_text)

----
Reservations Table:
Description:
The Reservations table keeps track of all flight reservations made by customers. Each record represents a unique reservation, detailing the customer, flight, reservation time, and status.
----
Columns:
--
reservation_id:
Description: A unique identifier for each reservation made on the platform.
Usage: This ID ensures that each reservation is distinct and can be referenced for customer inquiries, modifications, and operational tracking.
Type: INT64
--
customer_id:
Description: A reference to a customer from the Customers table who made the reservation.
Usage: Establishes which customer made a specific reservation, aiding in personalized user experiences, communication, and support.
Type: INT64
--
flight_id:
Description: Refers to a specific flight from the Flights table.
Usage: Ensures that the reservation corresponds to a specific flight, aiding in managing flight capacities and customer communications.
Type: INT64
--
reservation_datetime:
Descrip

5. SQL Statement Construction
    - The gathered information, which includes the user query, detected intent, and extracted entities, is integrated with mapped tables and filtered schema info, such as table and column descriptions. This comprehensive data info is compiled into a structured prompt for the LLM as shown below.
    - Please construct a SQL query using the information provided below: 
        - Input Parameters:
            - INTENT: {intent}
            - EXTRACTED_ENTITIES: {entities}
            - MAPPED_TABLES: {tables}

        - User Query:
            - {user_query}

        - Table Schemas:
            - {filtered_table_schemas_text}

        - Note: 
        -   Please prefix the table names with `flight_reservations`.

In [90]:
messages = []
template = "You are a SQL master expert capable of writing complex SQL query in BigQuery."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
messages.append(system_message_prompt)

In [91]:
human_template = """Please construct a SQL query using the information provided below:

Input Parameters:
-----------------
INTENT: {intent}
EXTRACTED_ENTITIES: {entities}
MAPPED_TABLES: {tables}

User Query:
-----------
{user_query}

Table Schemas:
--------------
{filtered_table_schemas_text}

Note: 
- Please prefix the table names with `flight_reservations`."""

In [92]:
human_message = HumanMessagePromptTemplate.from_template(human_template)
messages.append(human_message)
chat_prompt = ChatPromptTemplate.from_messages(messages)

In [93]:
request = chat_prompt.format_prompt(intent=intent, entities=entities, tables=tables, user_query=user_query,
                                    filtered_table_schemas_text=filtered_table_schemas_text).to_messages()

In [95]:
%%time 

response = llm(request)
sql = '\n'.join(response.content.strip().split('\n')[1:-1])
logger.info(sql)

SELECT flight_reservations.reservation_id, flight_reservations.customer_id, flight_reservations.flight_id, flight_reservations.reservation_datetime, flight_reservations.status
FROM flight_reservations.reservations AS flight_reservations
JOIN flight_reservations.flights AS flight_flights
ON flight_reservations.flight_id = flight_flights.flight_id
WHERE flight_reservations.reservation_datetime >= TIMESTAMP('2023-10-10') 
AND flight_reservations.reservation_datetime < TIMESTAMP('2023-10-16');
SELECT flight_reservations.reservation_id, flight_reservations.customer_id, flight_reservations.flight_id, flight_reservations.reservation_datetime, flight_reservations.status
FROM flight_reservations.reservations AS flight_reservations
JOIN flight_reservations.flights AS flight_flights
ON flight_reservations.flight_id = flight_flights.flight_id
WHERE flight_reservations.reservation_datetime >= TIMESTAMP('2023-10-10') 
AND flight_reservations.reservation_datetime < TIMESTAMP('2023-10-16');


CPU times: user 14.1 ms, sys: 3.56 ms, total: 17.6 ms
Wall time: 2.28 s


#### Pattern 2: LLM Flow with RAG

- If we’re dealing with hundreds or even thousands of tables in BigQuery, or instances where tables are broad with thousands of columns, the previous pattern doesn’t hold up. It’s practically impossible to map the detected intent to the corresponding tables. The same issue arises when we face new scenarios or query types, the first pattern can’t reliably map intent to the correct tables. 
- Here in RAG, We turn the table descriptions and schema descriptions (columns) into embeddings, then index these embeddings for search. 

1. Embedding and Indexing Descriptions: 
    - Start by encoding the table and column descriptions into embeddings using a text embedding model. Once the descriptions are encoded, create two indices — one for table descriptions and another for column descriptions. 

In [105]:
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.prompts.chat import HumanMessagePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.embeddings import VertexAIEmbeddings
from langchain.document_loaders import JSONLoader
from langchain.embeddings.base import Embeddings
from langchain.chat_models import ChatVertexAI
from langchain.vectorstores import FAISS
from google.cloud import bigquery
from typing import List
from tqdm import tqdm
import logging
import json
import os 

In [106]:
logger = logging.getLogger('langchain')
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

In [109]:
from langchain.embeddings import HuggingFaceBgeEmbeddings

model_name = "BAAI/bge-small-en-v1.5"
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity

bge_embeddings = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs={'device': 'cpu'},
    encode_kwargs=encode_kwargs
)

In [110]:
query = "Provide a list of all flight reservations from October 10th to October 15th, 2023"

In [113]:
documents = JSONLoader(file_path='text2sql_data/rag-schema/tables.jsonl', jq_schema='.', text_content=False, json_lines=True).load()

In [114]:
db = FAISS.from_documents(documents=documents, embedding=bge_embeddings)

2. 
    - The encoded query is compared against a pre-indexed repository of table descriptions to identify relevant tables and index of column descriptions that align with the user’s intent. With this a multi-layered search strategy is employed in this architectural pattern. 

In [115]:
retriever = db.as_retriever(search_type='mmr', search_kwargs={'k': 5, 'lambda_mult': 1})
matched_documents = retriever.get_relevant_documents(query=query)

  warn_deprecated(


In [116]:
matched_documents

[Document(page_content='{"dataset_name": "flight_reservations", "table_name": "reservations", "description": "The `reservations` table keeps track of all flight bookings, linking them to customers and flights and recording when the reservation was made and its current status.", "example_queries": ["Find all reservations made by a specific customer.", "Show all reservations for a particular flight.", "Retrieve reservations made within a specific time period.", "Which reservations have a particular status (e.g., confirmed, cancelled, etc.)?", "List the reservations in chronological order."]}', metadata={'source': '/Users/nisargmehta/Documents/LLM/langchain/RAG/text2sql_data/rag-schema/tables.jsonl', 'seq_num': 33}),
 Document(page_content='{"dataset_name": "flight_reservations", "table_name": "flights", "description": "The `flights` table logs details about each flight, including origin, destination, departure and arrival times, the carrier, and the price.", "example_queries": ["Find all

In [118]:
documents = JSONLoader(file_path='text2sql_data/rag-schema/columns.jsonl', jq_schema='.', text_content=False, json_lines=True).load()
db = FAISS.from_documents(documents=documents, embedding=bge_embeddings)

In [119]:
search_kwargs = {'k': 20}

retriever = db.as_retriever(search_type='similarity', search_kwargs=search_kwargs)

In [120]:
matched_columns = retriever.get_relevant_documents(query=query)
logger.info(matched_columns)

[Document(page_content='{"dataset_name": "flight_reservations", "table_name": "flights", "column_name": "arrival_datetime", "description": "The arrival time of the flight.", "usage": "Informs users and helps them plan their travel.", "data_type": "DATETIME"}', metadata={'source': '/Users/nisargmehta/Documents/LLM/langchain/RAG/text2sql_data/rag-schema/columns.jsonl', 'seq_num': 56}), Document(page_content='{"dataset_name": "flight_reservations", "table_name": "reservations", "column_name": "status", "description": "The status of the reservation (e.g., confirmed, cancelled).", "usage": "Informs users and staff of the current state of the reservation.", "data_type": "STRING"}', metadata={'source': '/Users/nisargmehta/Documents/LLM/langchain/RAG/text2sql_data/rag-schema/columns.jsonl', 'seq_num': 63}), Document(page_content='{"dataset_name": "flight_reservations", "table_name": "flights", "column_name": "departure_datetime", "description": "The departure time of the flight.", "usage": "In

In [121]:
matched_columns_filtered = []

# LangChain filters does not support multiple values at the moment
for i, column in enumerate(matched_columns):
    page_content = json.loads(column.page_content)
    dataset_name = page_content['dataset_name']
    if dataset_name == 'flight_reservations':
        matched_columns_filtered.append(page_content)

logger.info(matched_columns_filtered)

[{'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'arrival_datetime', 'description': 'The arrival time of the flight.', 'usage': 'Informs users and helps them plan their travel.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'reservations', 'column_name': 'status', 'description': 'The status of the reservation (e.g., confirmed, cancelled).', 'usage': 'Informs users and staff of the current state of the reservation.', 'data_type': 'STRING'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'departure_datetime', 'description': 'The departure time of the flight.', 'usage': 'Informs users and helps them plan their travel.', 'data_type': 'DATETIME'}, {'dataset_name': 'flight_reservations', 'table_name': 'flights', 'column_name': 'destination', 'description': 'The arrival location of the flight.', 'usage': 'Used to find flights and plan journeys.', 'data_type': 'STRING'}, {'dataset_name': 'flight

In [122]:
matched_columns_cleaned = []

for doc in matched_columns_filtered:
    dataset_name = doc['dataset_name']
    table_name = doc['table_name']
    column_name = doc['column_name']
    data_type = doc['data_type']
    matched_columns_cleaned.append(f'dataset_name={dataset_name}|table_name={table_name}|column_name={column_name}|data_type={data_type}')
    
matched_columns_cleaned = '\n'.join(matched_columns_cleaned)
logger.info(matched_columns_cleaned)

dataset_name=flight_reservations|table_name=flights|column_name=arrival_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=reservations|column_name=status|data_type=STRING
dataset_name=flight_reservations|table_name=flights|column_name=departure_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=destination|data_type=STRING
dataset_name=flight_reservations|table_name=reservations|column_name=reservation_datetime|data_type=DATETIME
dataset_name=flight_reservations|table_name=flights|column_name=origin|data_type=STRING
dataset_name=flight_reservations|table_name=flights|column_name=carrier|data_type=STRING
dataset_name=flight_reservations|table_name=flights|column_name=price|data_type=FLOAT64
dataset_name=flight_reservations|table_name=customers|column_name=date_of_birth|data_type=DATE
dataset_name=flight_reservations|table_name=flights|column_name=flight_id|data_type=INT64
dataset_name=flight_reservations|table_name=reservati

3. Text-to-SQL generation

In [123]:
messages = []
template = "You are a SQL master expert capable of writing complex SQL query in BigQuery."
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
messages.append(system_message_prompt)

In [124]:
human_template = """Given the following inputs:
USER_QUERY:
--
{query}
--
MATCHED_SCHEMA: 
--
{matched_schema}
--
Please construct a SQL query using the MATCHED_SCHEMA and the USER_QUERY provided above. 
The goal is to determine the availability of hotels based on the provided info. 

IMPORTANT: Use ONLY the column names (column_name) mentioned in MATCHED_SCHEMA. DO NOT USE any other column names outside of this. 
IMPORTANT: Associate column_name mentioned in MATCHED_SCHEMA only to the table_name specified under MATCHED_SCHEMA.
NOTE: Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed. 
"""

In [125]:
human_message = HumanMessagePromptTemplate.from_template(human_template)
messages.append(human_message)

In [126]:
chat_prompt = ChatPromptTemplate.from_messages(messages)
request = chat_prompt.format_prompt(query=query, matched_schema=matched_columns_cleaned).to_messages()

In [127]:
response = llm(request)
sql = '\n'.join(response.content.strip().split('\n')[1:-1])
logger.info(sql)

SELECT 
    r.reservation_id AS reservation_id,
    r.customer_id AS customer_id,
    r.flight_id AS flight_id,
    r.status AS reservation_status,
    r.reservation_datetime AS reservation_datetime,
    f.origin AS flight_origin,
    f.destination AS flight_destination,
    f.carrier AS flight_carrier,
    f.price AS flight_price,
    f.departure_datetime AS flight_departure_datetime,
    f.arrival_datetime AS flight_arrival_datetime
FROM 
    flight_reservations.reservations AS r
JOIN 
    flight_reservations.flights AS f
ON 
    r.flight_id = f.flight_id
WHERE 
    f.departure_datetime >= '2023-10-10' 
    AND f.arrival_datetime <= '2023-10-15'


#### Pattern 3: SQL Agent

- The agents have the capacity to analyze SQL databases by transforming natural language queries into SQL commands. The LangChain SQL agent for BigQuery exemplifies this capability, enabling users to interact with BigQuery databases using natural language. 

In [39]:
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.chat_models import ChatVertexAI
from langchain.sql_database import SQLDatabase
from langchain.agents import create_sql_agent
from urllib.parse import quote_plus
import pkg_resources
import sqlalchemy
import langchain
import logging
import os 

In [40]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

In [41]:
PROJECT = 'fluent-buckeye-422006'
LOCATION = 'us-central1'
MODEL_NAME = 'codechat-bison@latest'
DATASET = 'flight_reservations'
SERVICE_ACCOUNT_KEY_PATH = 'gcp_bigquery.json'
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_KEY_PATH
SERVICE_ACCOUNT_KEY_PATH_ABS = os.path.abspath(SERVICE_ACCOUNT_KEY_PATH)
SERVICE_ACCOUNT_EMAIL = 'bigquery-database@fluent-buckeye-422006.iam.gserviceaccount.com'  # Can be found in the service account JSON key above
OAUTH_FLAG = 0 # 0 means we are using the service account credentials key for auth and 1 means we are using a refresh token 
CATALOG = 'fluent-buckeye-422006'  # same as project name 
DRIVER_PATH_ABS = '/Library/simba/googlebigqueryodbc/lib/libgooglebigqueryodbc_sbu.dylib'
EMAIL = 'bigquery-database@fluent-buckeye-422006.iam.gserviceaccount.com'

In [42]:
params = {
    'KeyFilePath': quote_plus(SERVICE_ACCOUNT_KEY_PATH_ABS),
    'Driver': quote_plus(DRIVER_PATH_ABS),
    'OAuthMechanism': OAUTH_FLAG,
    'Catalog': CATALOG,
    'Dataset': DATASET,
    'Email': EMAIL,
}

SQLALCHEMY_URI = f"bigquery://{PROJECT}/{DATASET}?{'&'.join(f'{k}={v}' for k, v in params.items())}"
logger.info(f'SQLALCHEMY_URI={SQLALCHEMY_URI}')

SQLALCHEMY_URI=bigquery://fluent-buckeye-422006/flight_reservations?KeyFilePath=%2FUsers%2Fnisargmehta%2FDocuments%2FLLM%2Flangchain%2FRAG%2Fgcp_bigquery.json&Driver=%2FLibrary%2Fsimba%2Fgooglebigqueryodbc%2Flib%2Flibgooglebigqueryodbc_sbu.dylib&OAuthMechanism=0&Catalog=fluent-buckeye-422006&Dataset=flight_reservations&Email=bigquery-database@fluent-buckeye-422006.iam.gserviceaccount.com
SQLALCHEMY_URI=bigquery://fluent-buckeye-422006/flight_reservations?KeyFilePath=%2FUsers%2Fnisargmehta%2FDocuments%2FLLM%2Flangchain%2FRAG%2Fgcp_bigquery.json&Driver=%2FLibrary%2Fsimba%2Fgooglebigqueryodbc%2Flib%2Flibgooglebigqueryodbc_sbu.dylib&OAuthMechanism=0&Catalog=fluent-buckeye-422006&Dataset=flight_reservations&Email=bigquery-database@fluent-buckeye-422006.iam.gserviceaccount.com
SQLALCHEMY_URI=bigquery://fluent-buckeye-422006/flight_reservations?KeyFilePath=%2FUsers%2Fnisargmehta%2FDocuments%2FLLM%2Flangchain%2FRAG%2Fgcp_bigquery.json&Driver=%2FLibrary%2Fsimba%2Fgooglebigqueryodbc%2Flib%2Flibg

In [43]:
from langchain.sql_database import SQLDatabase
db = SQLDatabase.from_uri(SQLALCHEMY_URI)



In [44]:
from langchain.chat_models import ChatOpenAI
os.environ["OPEN_API_KEY"] = ""
llm = ChatOpenAI(temperature=0, openai_api_key=os.environ["OPEN_API_KEY"])

- Scenario 1: Retrieve Active Reservations for a Specific Date Range

In [47]:
question = "Provide a list of all flight reservations from October 10th to October 15th, 2023"
agent_executor.run(question)



[1m> Entering new SQL Agent Executor chain...[0m


[32;1m[1;3mI need to query the database to retrieve flight reservations within the specified date range.
Action: sql_db_query_checker
Action Input: SELECT * FROM reservations WHERE reservation_date BETWEEN '2023-10-10' AND '2023-10-15'[0m[36;1m[1;3mSELECT * 
FROM reservations 
WHERE reservation_date >= '2023-10-10' AND reservation_date <= '2023-10-15'[0m[32;1m[1;3mThe query looks correct now, I can proceed to execute it.
Action: sql_db_query
Action Input: SELECT * 
FROM reservations 
WHERE reservation_date >= '2023-10-10' AND reservation_date <= '2023-10-15'[0m

  cursor = connection.execute(


[36;1m[1;3mError: (google.cloud.bigquery.dbapi.exceptions.DatabaseError) 400 POST https://bigquery.googleapis.com/bigquery/v2/projects/fluent-buckeye-422006/queries?prettyPrint=false: Unrecognized name: reservation_date; Did you mean reservation_datetime? at [3:7]
[SQL: SELECT * 
FROM reservations 
WHERE reservation_date >= '2023-10-10' AND reservation_date <= '2023-10-15']
(Background on this error at: https://sqlalche.me/e/14/4xp6)[0m[32;1m[1;3mI need to check the schema to see the correct field name for reservation date.
Action: sql_db_schema
Action Input: reservations[0m[33;1m[1;3m
CREATE TABLE `reservations` (
	`reservation_id` INT64, 
	`customer_id` INT64, 
	`flight_id` INT64, 
	`reservation_datetime` TIMESTAMP, 
	`status` STRING
)

/*
3 rows from reservations table:
reservation_id	customer_id	flight_id	reservation_datetime	status
4	3	4	2023-10-05 08:30:00+00:00	Cancelled
9	8	8	2023-10-20 09:00:00+00:00	Cancelled
13	13	12	2023-10-30 14:50:00+00:00	Cancelled
*/[0m[32;1m

"[(6, 6, 6, datetime.datetime(2023, 10, 10, 10, 0, tzinfo=datetime.timezone.utc), 'Confirmed'), (7, 6, 7, datetime.datetime(2023, 10, 12, 11, 30, tzinfo=datetime.timezone.utc), 'Confirmed')]"

- Scenario 2: Identify customers who made reservations in the past N days

In [48]:
question = "Identify all customers who have made flight reservations within the last 7 days."
_ = agent_executor.run(question)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mI need to query the database to find customers who have made flight reservations within the last 7 days.
Action: sql_db_query_checker
Action Input: SELECT * FROM customers WHERE reservation_date >= DATE_SUB(NOW(), INTERVAL 7 DAY)[0m[36;1m[1;3mSELECT * 
FROM customers 
WHERE reservation_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 7 DAY)[0m[32;1m[1;3mThe query looks correct, I should execute it to get the result.
Action: sql_db_query
Action Input: SELECT * FROM customers WHERE reservation_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 7 DAY)[0m[36;1m[1;3mError: (google.cloud.bigquery.dbapi.exceptions.DatabaseError) 400 POST https://bigquery.googleapis.com/bigquery/v2/projects/fluent-buckeye-422006/queries?prettyPrint=false: Unrecognized name: reservation_date at [1:31]
[SQL: SELECT * FROM customers WHERE reservation_date >= DATE_SUB(CURRENT_DATE(), INTERVAL 7 DAY)]
(Background on this error at: https://sqlalche.me/e/14/4

- Scenario 3: Calculate Monthly Revenue

In [49]:
question = "Calculate the total revenue generated from transactions in October 2023, specifically from all reservations with a Confirmed status."
_ = agent_executor.run(question)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mI need to retrieve the total revenue generated from transactions in October 2023 for reservations with a Confirmed status.
Action: sql_db_list_tables
Action Input: [0m[38;5;200m[1;3mcustomers, flights, reservations, transactions[0m[32;1m[1;3mI need to find the tables that contain information about reservations and transactions.
Action: sql_db_schema
Action Input: reservations, transactions[0m[33;1m[1;3m
CREATE TABLE `reservations` (
	`reservation_id` INT64, 
	`customer_id` INT64, 
	`flight_id` INT64, 
	`reservation_datetime` TIMESTAMP, 
	`status` STRING
)

/*
3 rows from reservations table:
reservation_id	customer_id	flight_id	reservation_datetime	status
4	3	4	2023-10-05 08:30:00+00:00	Cancelled
9	8	8	2023-10-20 09:00:00+00:00	Cancelled
13	13	12	2023-10-30 14:50:00+00:00	Cancelled
*/


CREATE TABLE `transactions` (
	`transaction_id` INT64, 
	`reservation_id` INT64, 
	`amount` INT64, 
	`transaction_datetime` TIME

- Scenario 5: Customer Age Group
    - Group customers by age brackets and count the number in each bracket.

In [50]:
question = "Group customers into five distinct age brackets and count the number of customers in each bracket."
_ = agent_executor.run(question)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3mI need to group customers by age brackets and count the number of customers in each bracket.
Action: sql_db_query_checker
Action Input: SELECT COUNT(*), 
    CASE
        WHEN age < 20 THEN '0-19'
        WHEN age BETWEEN 20 AND 29 THEN '20-29'
        WHEN age BETWEEN 30 AND 39 THEN '30-39'
        WHEN age BETWEEN 40 AND 49 THEN '40-49'
        ELSE '50+'
    END AS age_bracket
FROM customers
GROUP BY age_bracket;[0m[36;1m[1;3mSELECT COUNT(*), 
    CASE
        WHEN age < 20 THEN '0-19'
        WHEN age BETWEEN 20 AND 29 THEN '20-29'
        WHEN age BETWEEN 30 AND 39 THEN '30-39'
        WHEN age BETWEEN 40 AND 49 THEN '40-49'
        ELSE '50+'
    END AS age_bracket
FROM customers
GROUP BY age_bracket;[0m[32;1m[1;3mThe query looks correct, I can now execute it to get the result.
Action: sql_db_query
Action Input: SELECT COUNT(*), 
    CASE
        WHEN age < 20 THEN '0-19'
        WHEN age BETWEEN 20 AND 29 TH

#### Pattern 4: Direct Schema Inference with Self-Correction

- This approach begins with the direct inference of the schema, utilizing a ‘seed prompt’ that instructs the LLM to construct an SQL query corresponding to a user’s inquiry. The execution of this initial prompt continues iteratively until it meets with success. Failures are treated as critical learning opportunity for the LLM, allowing it to scrutinize tracebacks and utilize error messages to refine and evolve the seed prompt into an improved query iteration. This self-corrected query is then tested; if successful, the process is complete; if not, the iterative cycle of refinement and testing is sustained until a successful outcome is secured.

In [45]:
from vertexai.language_models import CodeChatModel
from vertexai.language_models import CodeChatSession
from google.cloud import bigquery
import logging 
import os 

In [2]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

In [46]:
PROJECT_ID = 'fluent-buckeye-422006'
MODEL_NAME = 'codechat-bison@latest'
TEMPERATURE = 0 # default value = 0
MAX_OUTPUT_TOKENS = 2048  # length of the output response | overridding the default value which is 128
LOCATION = 'us-central1'

In [4]:
DATASET = 'flight_reservations'
TABLES = ['customers', 'flights', 'reservations', 'transactions', 'loyality_points']

In [48]:
SERVICE_ACCOUNT_CREDENTIALS = 'gcp_bigquery.json'
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_CREDENTIALS

from langchain.chat_models import ChatOpenAI
os.environ["OPEN_API_KEY"] = ""
llm = ChatOpenAI(temperature=0, openai_api_key=os.environ["OPEN_API_KEY"])

In [8]:
code_gen_model = llm
bq_client = bigquery.Client()

In [9]:
query = f"""
    SELECT *
    FROM `{PROJECT_ID}.{DATASET}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ({','.join([f'"{table}"' for table in TABLES])})
"""
logger.info(query)


    SELECT *
    FROM `fluent-buckeye-422006.flight_reservations.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS`
    WHERE table_name in ("customers","flights","reservations","transactions","loyality_points")



In [10]:
schema_columns = bq_client.query(query=query).to_dataframe()
schema_columns

Unnamed: 0,table_catalog,table_schema,table_name,column_name,field_path,data_type,description,collation_name,rounding_mode
0,fluent-buckeye-422006,flight_reservations,transactions,transaction_id,transaction_id,INT64,,,
1,fluent-buckeye-422006,flight_reservations,transactions,reservation_id,reservation_id,INT64,,,
2,fluent-buckeye-422006,flight_reservations,transactions,amount,amount,INT64,,,
3,fluent-buckeye-422006,flight_reservations,transactions,transaction_datetime,transaction_datetime,TIMESTAMP,,,
4,fluent-buckeye-422006,flight_reservations,reservations,reservation_id,reservation_id,INT64,,,
5,fluent-buckeye-422006,flight_reservations,reservations,customer_id,customer_id,INT64,,,
6,fluent-buckeye-422006,flight_reservations,reservations,flight_id,flight_id,INT64,,,
7,fluent-buckeye-422006,flight_reservations,reservations,reservation_datetime,reservation_datetime,TIMESTAMP,,,
8,fluent-buckeye-422006,flight_reservations,reservations,status,status,STRING,,,
9,fluent-buckeye-422006,flight_reservations,flights,flight_id,flight_id,INT64,,,


In [11]:
schema_columns = schema_columns.to_markdown(index=False)
logger.info(schema_columns)

| table_catalog         | table_schema        | table_name   | column_name          | field_path           | data_type   | description   | collation_name   | rounding_mode   |
|:----------------------|:--------------------|:-------------|:---------------------|:---------------------|:------------|:--------------|:-----------------|:----------------|
| fluent-buckeye-422006 | flight_reservations | transactions | transaction_id       | transaction_id       | INT64       |               | NULL             |                 |
| fluent-buckeye-422006 | flight_reservations | transactions | reservation_id       | reservation_id       | INT64       |               | NULL             |                 |
| fluent-buckeye-422006 | flight_reservations | transactions | amount               | amount               | INT64       |               | NULL             |                 |
| fluent-buckeye-422006 | flight_reservations | transactions | transaction_datetime | transaction_datetime | TIMESTAMP  

In [57]:
def generate_and_execute_sql(prompt, max_tries=8):
    """
    Generate an SQL query using the code_gen_model and execute it using bq_client.
    
    Args:
    - prompt (str): Prompt to provide to the model for generating SQL.
    - max_tries (int): Maximum number of attempts to generate and execute SQL.
    
    Returns:
    - dict: A dictionary containing the successful dataframe or error messages and prompt evolution.
    """
    
    tries = 0
    error_messages = []
    prompts = [prompt]
    df = None

    # chat_session = CodeChatSession(model=llm, temperature=TEMPERATURE, max_output_tokens=MAX_OUTPUT_TOKENS)
    
    while tries < max_tries:
        logger.info(f'ATTEMPT: {tries+1}')
        try:
            # Predict SQL using the model
            response = llm(prompt)
            print(response)
            generated_sql_query = response.text
            generated_sql_query = '\n'.join(generated_sql_query.split('\n')[1:-1])
            logger.info('-' * 50)
            logger.info(generated_sql_query)
            logger.info('-' * 50)
            # Execute SQL using BigQuery client
            df = bq_client.query(generated_sql_query).to_dataframe()
            logger.info('SUCCEEDED')
            return {"dataframe": df, "prompts": prompts, "errors": error_messages}
        except Exception as e:
            logger.error('FAILED')
            # Catch the error, store the message, and try again
            msg = str(e)
            error_messages.append(msg)
            # Evolve the prompt by appending the error message and asking the model to correct it
            prompt = f"""Encountered an error: {msg}. 
To address this, please generate an alternative SQL query response that avoids this specific error. 
Follow the instructions mentioned above to remediate the error. 

Modify the below SQL query to resolve the issue:
{generated_sql_query}

Ensure the revised SQL query aligns precisely with the requirements outlined in the initial question."""
            prompts.append(prompt)
            tries += 1
        logger.info('=' * 100)

    return {"dataframe": df, "prompts": prompts, "errors": error_messages}

- Constructing the SEED prompt

In [58]:
seed_prompt = """
Please craft a SQL query for BigQuery that addresses the following QUESTION provided below. 
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below. 
When joining tables, employ type coercion to guarantee data type consistency for the join columns. 
Additionally, the output column names should specify units where applicable.\n
QUESTION:
{}\n
SCHEMA:
{}\n
IMPORTANT: 
Use ONLY DATETIME and DO NOT use TIMESTAMP.
--
Ensure your SQL query accurately defines both the start and end of the DATETIME range.
"""
logger.info(seed_prompt)


Please craft a SQL query for BigQuery that addresses the following QUESTION provided below. 
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below. 
When joining tables, employ type coercion to guarantee data type consistency for the join columns. 
Additionally, the output column names should specify units where applicable.

QUESTION:
{}

SCHEMA:
{}

IMPORTANT: 
Use ONLY DATETIME and DO NOT use TIMESTAMP.
--
Ensure your SQL query accurately defines both the start and end of the DATETIME range.



In [59]:
question = "Provide a list of all reservations from October 10th to October 15th, 2023"
prompt = seed_prompt.format(question, schema_columns)
logger.info(prompt)


Please craft a SQL query for BigQuery that addresses the following QUESTION provided below. 
Ensure you reference the appropriate BigQuery tables and column names provided in the SCHEMA below. 
When joining tables, employ type coercion to guarantee data type consistency for the join columns. 
Additionally, the output column names should specify units where applicable.

QUESTION:
Provide a list of all reservations from October 10th to October 15th, 2023

SCHEMA:
| table_catalog         | table_schema        | table_name   | column_name          | field_path           | data_type   | description   | collation_name   | rounding_mode   |
|:----------------------|:--------------------|:-------------|:---------------------|:---------------------|:------------|:--------------|:-----------------|:----------------|
| fluent-buckeye-422006 | flight_reservations | transactions | transaction_id       | transaction_id       | INT64       |               | NULL             |                 |
| flu

In [63]:
# %%time

# response = generate_and_execute_sql(prompt=prompt)
# sql_output = response['dataframe']
# sql_output

### Project - NL2SQL

##### Building a Basic NL2SQL Model

In [72]:
 import os
 os.environ["OPENAI_API_KEY"] = ""

 db_user = "root"
 db_password = "Nisarg_12345"
 db_host = "localhost"
 db_name = "classicmodels"
 from langchain_community.utilities.sql_database import SQLDatabase
#  db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}",sample_rows_in_table_info=1,include_tables=['customers','orders'],custom_table_info={'customers':"customer"})
 db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")
 print(db.dialect)
 print(db.get_usable_table_names())
 print(db.table_info)

mysql
['customers', 'employees', 'offices', 'orderdetails', 'orders', 'payments', 'productlines', 'products']

CREATE TABLE customers (
	`customerNumber` INTEGER NOT NULL, 
	`customerName` VARCHAR(50) NOT NULL, 
	`contactLastName` VARCHAR(50) NOT NULL, 
	`contactFirstName` VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	`addressLine1` VARCHAR(50) NOT NULL, 
	`addressLine2` VARCHAR(50), 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50), 
	`postalCode` VARCHAR(15), 
	country VARCHAR(50) NOT NULL, 
	`salesRepEmployeeNumber` INTEGER, 
	`creditLimit` DECIMAL(10, 2), 
	PRIMARY KEY (`customerNumber`), 
	CONSTRAINT customers_ibfk_1 FOREIGN KEY(`salesRepEmployeeNumber`) REFERENCES employees (`employeeNumber`)
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci

/*
3 rows from customers table:
customerNumber	customerName	contactLastName	contactFirstName	phone	addressLine1	addressLine2	city	state	postalCode	country	salesRepEmployeeNumber	creditLimit
103	Atelier graphique	Schmit

In [76]:
os.environ['LANGCHAIN_TRACING_V2']="true"
os.environ['LANGCHAIN_API_KEY'] = "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"

- The query and response can be analyzed in langsmith (platform by langchain)

In [77]:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
generate_query = create_sql_query_chain(llm, db) # this is chain
query = generate_query.invoke({"question": "what is price of `1968 Ford Mustang`"})
# "what is price of `1968 Ford Mustang`"
print(query)

SELECT `buyPrice`, `MSRP`
FROM products
WHERE `productName` = '1968 Ford Mustang'
LIMIT 1;


- Executing the generated SQL query against your database retrieves the data you're looking for, which LangChain can then present in a user-friendly format.

In [78]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
execute_query = QuerySQLDataBaseTool(db=db) # this is also chain
execute_query.invoke(query)

# Decimal('95.34') is buyPrice, Decimal('194.57') is MSRP

"[(Decimal('95.34'), Decimal('194.57'))]"

In [80]:
# Both chain can be combined - create_sql_query_chain and QuerySQLDataBaseTool
chain = generate_query | execute_query
chain.invoke({'question':"how many orders are there?"})

'[(326,)]'

- Rephrasing Answers for Enhanced Clarity: 

In [81]:
 from operator import itemgetter

 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import RunnablePassthrough

 answer_prompt = PromptTemplate.from_template(
     """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

 Question: {question}
 SQL Query: {query}
 SQL Result: {result}
 Answer: """
 )

 rephrase_answer = answer_prompt | llm | StrOutputParser() # chain

# Combining generate_query | execute_query
# RunnablePassthrough takes the question (user query) and with "RunnablePassthrough.assign(query=generate_query)" it generate the SQL with generate_query function 
# Then this is passed as an input and with "itemgetter("query") | execute_query", we get the result in text format because the SQL query is passed as an input for 
# execute_query function which returns the value. As in rephrase_answer, for answer_prompt we need Question, SQL Query, SQL result : We have this 3, which is passed to LLM,
# and then to StrOutputParser for converting in string format.
 chain = (
     RunnablePassthrough.assign(query=generate_query).assign(
         result = itemgetter("query") | execute_query
     )
     | rephrase_answer
 )

 chain.invoke({"question": "How many customers have an order count greater than 5"})

'There are 2 customers who have an order count greater than 5.'

##### Few-Shot Examples & Dynamic Few-Shot Examples

In [84]:
examples = [
    {
        "input": "List all customers in France with a credit limit over 20,000.",
        "query": "SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;"
    },
    {
        "input": "Get the highest payment amount made by any customer.",
        "query": "SELECT MAX(amount) FROM payments;"
    },
    {
        "input": "Show product details for products in the 'Motorcycles' product line.",
        "query": "SELECT * FROM products WHERE productLine = 'Motorcycles';"
    },
    {
        "input": "Retrieve the names of employees who report to employee number 1002.",
        "query": "SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;"
    },
    {
        "input": "List all products with a stock quantity less than 7000.",
        "query": "SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;"
    },
    {
     'input':"what is price of `1968 Ford Mustang`",
     "query": "SELECT `buyPrice`, `MSRP` FROM products  WHERE `productName` = '1968 Ford Mustang' LIMIT 1;"   
    }
]

In [85]:
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

 example_prompt = ChatPromptTemplate.from_messages(
     [
         ("human", "{input}\nSQLQuery:"),
         ("ai", "{query}"),
     ]
 )
 few_shot_prompt = FewShotChatMessagePromptTemplate(
     example_prompt=example_prompt,
     examples=examples,
     # input_variables=["input","top_k"],
     input_variables=["input"],
 )
 print(few_shot_prompt.format(input1="How many products are there?"))

Human: List all customers in France with a credit limit over 20,000.
SQLQuery:
AI: SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;
Human: Get the highest payment amount made by any customer.
SQLQuery:
AI: SELECT MAX(amount) FROM payments;
Human: Show product details for products in the 'Motorcycles' product line.
SQLQuery:
AI: SELECT * FROM products WHERE productLine = 'Motorcycles';
Human: Retrieve the names of employees who report to employee number 1002.
SQLQuery:
AI: SELECT firstName, lastName FROM employees WHERE reportsTo = 1002;
Human: List all products with a stock quantity less than 7000.
SQLQuery:
AI: SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;
Human: what is price of `1968 Ford Mustang`
SQLQuery:
AI: SELECT `buyPrice`, `MSRP` FROM products  WHERE `productName` = '1968 Ford Mustang' LIMIT 1;


- Dynamic Few-Shot Example Selection: 
    - This advanced technique tailors the few-shot examples provided to the model based on the specific context of the user's query. It ensures that the guidance offered to the model is not just relevant but optimally aligned with the query's nuances, significantly boosting the model's ability to generate accurate SQL queries. 

In [87]:
 from langchain_community.vectorstores import Chroma
 from langchain_core.example_selectors import SemanticSimilarityExampleSelector
 from langchain_openai import OpenAIEmbeddings

 vectorstore = Chroma()
 vectorstore.delete_collection()
 example_selector = SemanticSimilarityExampleSelector.from_examples(
     examples,
     OpenAIEmbeddings(),
     vectorstore,
     k=2,
     input_keys=["input"],
 )
 example_selector.select_examples({"input": "Who is the highest paying customer?"})

[{'input': 'Get the highest payment amount made by any customer.',
  'query': 'SELECT MAX(amount) FROM payments;'},
 {'input': 'List all customers in France with a credit limit over 20,000.',
  'query': "SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;"}]

In [89]:
few_shot_prompt = FewShotChatMessagePromptTemplate(
     example_prompt=example_prompt,
     example_selector=example_selector,
     input_variables=["input","top_k"],
 )
print(few_shot_prompt.format(input="How many products are there?"))

Human: List all products with a stock quantity less than 7000.
SQLQuery:
AI: SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;
Human: Show product details for products in the 'Motorcycles' product line.
SQLQuery:
AI: SELECT * FROM products WHERE productLine = 'Motorcycles';


In [92]:
 final_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
         few_shot_prompt,
         ("human", "{input}"),
     ]
 )
 print(final_prompt.format(input="How many products are there?",table_info="some table info"))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries.
Human: List all products with a stock quantity less than 7000.
SQLQuery:
AI: SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;
Human: Show product details for products in the 'Motorcycles' product line.
SQLQuery:
AI: SELECT * FROM products WHERE productLine = 'Motorcycles';
Human: How many products are there?


In [94]:
generate_query = create_sql_query_chain(llm, db,final_prompt)
chain = (
 RunnablePassthrough.assign(query=generate_query).assign(
     result=itemgetter("query") | execute_query
 )
 | rephrase_answer
)
chain.invoke({"question": "How many csutomers with credit limit more than 50000"})

'There are 85 customers with a credit limit greater than 50000.'

##### Dynamic Relevant Table Selection

- In the realm of NL2SQL models, especially when dealing with complex databases featuring 100+ tables. With databases growing in complexity and size, it's impractical and costly in terms of prompt token usage to include the schema of every table in the initial prompt for generating SQL queries. 
- Dynamic relevant table selection emerges as a solution to this challenge, focusing the model's attention only on the tables pertinent to the user's query

In [95]:
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
import pandas as pd

def get_table_details():
    # Read the CSV file into a DataFrame
    table_description = pd.read_csv("data/database_table_descriptions.csv")
    table_docs = []

    # Iterate over the DataFrame rows to create Document objects
    table_details = ""
    for index, row in table_description.iterrows():
        table_details = table_details + "Table Name:" + row['Table'] + "\n" + "Table Description:" + row['Description'] + "\n\n"

    return table_details


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")

# table_names = "\n".join(db.get_usable_table_names())
table_details = get_table_details()
print(table_details)


Table Name:productlines
Table Description:Stores information about the different product lines offered by the company, including a unique name, textual description, HTML description, and image. Categorizes products into different lines.

Table Name:products
Table Description:Contains details of each product sold by the company, including code, name, product line, scale, vendor, description, stock quantity, buy price, and MSRP. Linked to the productlines table.

Table Name:offices
Table Description:Holds data on the company's sales offices, including office code, city, phone number, address, state, country, postal code, and territory. Each office is uniquely identified by its office code.

Table Name:employees
Table Description:Stores information about employees, including number, last name, first name, job title, contact info, and office code. Links to offices and maps organizational structure through the reportsTo attribute.

Table Name:customers
Table Description:Captures data on cus

In [96]:
table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:

{table_details}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

# create_extraction_chain_pydantic takes the table_details_prompt and llm that after extration gives the "Table" name
table_chain = create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt)
tables = table_chain.invoke({"input": "give me details of customer and their order count"})
tables

[Table(name='customers'), Table(name='orders')]

In [97]:
def get_tables(tables: List[Table]) -> List[str]:
    tables  = [table.name for table in tables]
    return tables

select_table = {"input": itemgetter("question")} | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt) | get_tables
select_table.invoke({"question": "give me details of customer and their order count"})

['customers', 'orders']

In [98]:
# The generate query part wil check if to use all tables or to use the table mentioned in the variable "table_names_to_use". Both of them are passed in chain format.
chain = (
RunnablePassthrough.assign(table_names_to_use=select_table) |
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)
chain.invoke({"question": "How many cutomers with order count more than 5"})

'There are 2 customers with an order count of more than 5.'

##### Adding Memory for Follow-up Database Queries

In [99]:
 from langchain.memory import ChatMessageHistory
 history = ChatMessageHistory()

In [100]:
# MessagesPlaceholder will hold the conversation history
final_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions"),
         few_shot_prompt,
         MessagesPlaceholder(variable_name="messages"),
         ("human", "{input}"),
     ]
 )
print(final_prompt.format(input="How many products are there?",table_info="some table info",messages=[]))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions
Human: List all products with a stock quantity less than 7000.
SQLQuery:
AI: SELECT productName, quantityInStock FROM products WHERE quantityInStock < 7000;
Human: Show product details for products in the 'Motorcycles' product line.
SQLQuery:
AI: SELECT * FROM products WHERE productLine = 'Motorcycles';
Human: How many products are there?


In [102]:
generate_query = create_sql_query_chain(llm, db,final_prompt)

chain = (
 RunnablePassthrough.assign(table_names_to_use=select_table) |
 RunnablePassthrough.assign(query=generate_query).assign(
     result=itemgetter("query") | execute_query
 )
 | rephrase_answer
)

In [103]:
question = "How many cutomers with order count more than 5"
response = chain.invoke({"question": question,"messages":history.messages})

In [104]:
history.add_user_message(question)
history.add_ai_message(response)
history.messages

[HumanMessage(content='How many cutomers with order count more than 5'),
 AIMessage(content='There are 2 customers with an order count of more than 5.')]

In [106]:
response = chain.invoke({"question": "Can you list their names?","messages":history.messages})
response

"The names of the customers who have placed more than 5 orders are 'Mini Gifts Distributors Ltd.' and 'Euro+ Shopping Channel'."