# Best Practices for Text-to-SQL Prompt Engineering with Llama 3

---

## Introduction

This notebook introduces a versatile approach that leverages Llama 3 models on Amazon Bedrock/Amazon SageMaker JumpStart, including advanced prompt engineering, to convert natural language questions into executable SQL queries. Our approach generates SQL queries capable of joining data from multiple tables, enabling information retrieval from complex database structures. This multi-table capability is crucial in real-world scenarios where data is often distributed across various tables with intricate relationships, and queries need to combine information from multiple sources to provide comprehensive insights.

Our approach also scales to large schemas by dynamically retrieving only the table schemas most relevant to a given natural language question. Retrieval is powered by ChromaDB, a vector database: the question is matched against stored schema descriptions by semantic similarity, so the appropriate tables and relationships are identified automatically, without manual intervention.

Our solution can be applied in practical scenarios where companies manage numerous databases with intricate table relationships, such as in the finance industry for analyzing customer transactions across multiple accounts and products, or in healthcare for integrating patient records from various systems and data sources.

---
## Llama 3 Model Selection

Today, there are two Llama 3 models available on Amazon Bedrock:

### 1. Llama 3 8B

- **Description:** Ideal for limited computational power and resources, faster training times, and edge devices.
- **Max Tokens:** 2,048
- **Context Window:** 8,192
- **Languages:** English
- **Supported Use Cases:** Synthetic Text Generation, Text Classification, and Sentiment Analysis.

### 2. Llama 3 70B

- **Description:** Ideal for content creation, conversational AI, language understanding, research development, and enterprise applications. 
- **Max Tokens:** 2,048
- **Context Window:** 8,192
- **Languages:** English
- **Supported Use Cases:** Synthetic Text Generation and Accuracy, Text Classification and Nuance, Sentiment Analysis and Nuance Reasoning, Language Modeling, Dialogue Systems, and Code Generation.

### Performance and Cost Trade-offs

The table below compares the model performance on the Massive Multitask Language Understanding (MMLU) benchmark and their on-demand pricing on Amazon Bedrock.

| Model           | MMLU Score | Price per 1,000 Input Tokens | Price per 1,000 Output Tokens |
|-----------------|------------|------------------------------|-------------------------------|
| Llama 3 8B | 68.4%      | \$0.0004                   | \$0.0006                    |
| Llama 3 70B | 82.0%      | \$0.00265                   | \$0.0035                     |
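
To make the trade-off concrete, the prices in the table can be turned into a per-request estimate. The sketch below is illustrative only: the dictionary keys `llama3-8b` and `llama3-70b` are labels chosen for this example, and the workload of 1,500 input tokens and 200 output tokens is a hypothetical figure, not a measurement from this notebook.

```python
# Bedrock on-demand prices per 1,000 tokens, copied from the table above.
# The keys are illustrative labels, not Bedrock model IDs.
PRICES_PER_1K = {
    "llama3-8b": {"input": 0.0004, "output": 0.0006},
    "llama3-70b": {"input": 0.00265, "output": 0.0035},
}


def estimate_request_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Return the estimated on-demand cost in USD for a single request."""
    price = PRICES_PER_1K[model]
    return (input_tokens / 1000) * price["input"] + (output_tokens / 1000) * price["output"]


# Hypothetical workload: 1,500 prompt tokens (schema + examples) and 200 generated tokens.
for model in PRICES_PER_1K:
    cost = estimate_request_cost(model, input_tokens=1500, output_tokens=200)
    print(f"{model}: ${cost:.6f} per request, ${cost * 10_000:.2f} per 10,000 requests")
```

Under these assumptions the 70B model costs roughly six to seven times as much per request as the 8B model, which is the price paid for the higher MMLU score.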

For more information, refer to the following links:

1. [Llama 3 Model Cards and Prompt Formats](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3)
2. [Amazon Bedrock Pricing Page](https://aws.amazon.com/bedrock/pricing/)


## The Approach to the Text-to-SQL Problem
This notebook covers the following approaches:

### Few-shot text-to-SQL (Single Table vs Multiple Tables)
Few-shot text-to-SQL is an approach for querying databases by translating natural language questions into SQL queries, using only a few examples supplied in the prompt rather than fine-tuning the model.

Providing just a few examples of natural language questions paired with the equivalent SQL queries allows models to learn the mapping from natural language to SQL.

Reference: https://arxiv.org/abs/2305.12586
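
As a minimal sketch of the idea, a few-shot prompt can be assembled from a schema, a handful of question/SQL pairs, and the new question. The function name `build_few_shot_prompt`, the column types in the schema, and the example pairs are assumptions made for illustration; they are not the prompts used later in this notebook.

```python
from typing import List, Tuple


def build_few_shot_prompt(table_schema: str, examples: List[Tuple[str, str]], question: str) -> str:
    """Assemble a few-shot prompt: schema, worked question/SQL pairs, then the new question."""
    parts = ["Table schema:", table_schema.strip(), ""]
    for example_question, example_sql in examples:
        parts.append(f"Question: {example_question}")
        parts.append(f"<sql>{example_sql}</sql>")
        parts.append("")
    parts.append(f"Question: {question}")
    return "\n".join(parts)


# Hypothetical schema and examples, for illustration only.
schema = "CREATE TABLE airplanes (Airplane_id INT, Producer VARCHAR(20), Type VARCHAR(10));"
examples = [
    ("How many airplanes are there?", "SELECT COUNT(*) FROM airplanes;"),
    ("List all Boeing airplanes.", "SELECT * FROM airplanes WHERE Producer = 'Boeing';"),
]
prompt = build_few_shot_prompt(schema, examples, "Which producers build the E195?")
print(prompt)
```

Ending the prompt on the unanswered question invites the model to continue the pattern set by the examples and reply with SQL wrapped in the same tags.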

### Few-shot text-to-SQL powered by ChromaDB (Schema Retrieval vs Schema Retrieval powered by ChromaDB)

This approach leverages ChromaDB, a vector database, to assist the few-shot text-to-SQL translation process. ChromaDB can be used in two ways:

1. **Schema Retrieval**: In this method, we manually provide the table schema within the prompt. When a natural language question is provided, the model uses the schema information to aid in generating the SQL query.

2. **Schema Retrieval powered by ChromaDB**: In this method, ChromaDB stores the database schema information, which includes table names, column names, and their descriptions. When a natural language question is provided, the model can retrieve relevant schema information from ChromaDB to aid in generating the SQL query.

By leveraging ChromaDB, the few-shot text-to-SQL translation process can benefit from efficient schema and sample data retrieval, potentially leading to better performance and generalization across different databases and query types.
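
The retrieval step can be pictured with a toy example. The sketch below deliberately swaps in a bag-of-words cosine similarity for ChromaDB and its embedding model so that it runs with the standard library alone; the table descriptions and the function names are hypothetical. A real vector store compares dense embeddings instead of word counts, but the ranking logic is similar in spirit.

```python
import math
import re
from collections import Counter
from typing import Dict, List


def bag_of_words(text: str) -> Counter:
    """Lowercase the text and count its alphanumeric tokens."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine_similarity(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse term-count vectors."""
    dot = sum(a[token] * b[token] for token in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def retrieve_schemas(question: str, schema_docs: Dict[str, str], top_k: int = 1) -> List[str]:
    """Return the names of the top_k tables whose descriptions best match the question."""
    query = bag_of_words(question)
    ranked = sorted(
        schema_docs,
        key=lambda name: cosine_similarity(query, bag_of_words(schema_docs[name])),
        reverse=True,
    )
    return ranked[:top_k]


# Hypothetical table descriptions; a vector store would hold embeddings of these.
schema_docs = {
    "airplanes": "airplanes table: airplane id, producer, type of each airplane",
    "flights": "flights table: flight number, arrival time, departure time, destination, airplane id",
}
print(retrieve_schemas("What is the destination of flight AA123?", schema_docs))
```

Only the schemas returned by the retrieval step need to be placed in the prompt, which keeps the prompt short even when the database has many tables.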

---

## Objectives

This notebook provides code snippets to assist with implementing two different approaches to converting a natural language question into a SQL query that answers it.

---

## Contents

1. [Getting Started](#getting-started)
    + [Install Dependencies](#step-1-install-dependencies)
    + [Setup Bedrock and Database](#step-2-set-up-bedrock-client-and-database-connection)
    + [Build Database](#step-3-build-database)
    + [Create Helper Functions](#step-4-create-helper-functions)
1. [Few-Shot Text-to-SQL](#few-shot-text-to-sql)
1. [Analyzing a Single Table with Few-Shot Learning](#analyzing-a-single-table-with-few-shot-learning)
    + [Create a Few-Shot Prompt](#step-1-create-a-few-shot-prompt)
    + [Execute Few-Shot Prompts](#step-2-execute-few-shot-prompts)
1. [Analyzing Multiple Table with Few-Shot Learning](#analyzing-multiple-table-with-few-shot-learning)
    + [Create a Few-Shot Prompt](#step-1-create-a-few-shot-prompt)
    + [Execute Few-Shot Prompts](#step-2-execute-few-shot-prompts)
1. [Limitations of Few-Shot Learning](#limitation-of-few-shot-learning)
1. [Few-shot text-to-SQL powered by ChromaDB](#few-shot-text-to-sql-powered-by-chromadb)
1. [Schema Retrieval](#schema-retrieval)
    + [Data Preprocessing](#step-1-data-processing)
    + [Ingest docs into ChromaDB](#step-2-ingest-docs-into-chromadb)
    + [Create a Few-Shot Prompt](#step-3-create-a-few-shot-prompt)
    + [Execute Few-Shot Prompts](#step-4-execute-few-shot-prompts)
    + [Conclusion](#step-5-conclusion)
1. [Enhanced Schema Retrieval with an Embedding Model](#enhanced-schema-retrieval-with-an-embedding-model)
    + [Ingest docs into ChromaDB](#step-1-ingest-docs-into-chromadb)
    + [Execute Few-Shot Prompts](#step-2-execute-few-shot-prompts)
    + [Conclusion](#step-3-conclusion)
---

### Tools

+ AWS Python SDKs [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html) to be able to submit API calls to [Amazon Bedrock](https://aws.amazon.com/bedrock/).

+ [LangChain](https://python.langchain.com/v0.1/docs/get_started/introduction/) is a framework that provides off-the-shelf components to make it easier to build applications with large language models. It is officially available for Python and JavaScript/TypeScript. In this notebook, LangChain is used to build a prompt template.

+ [ChromaDB](https://www.trychroma.com/) is a vector database that enables efficient semantic search, storage, and retrieval of unstructured data like text, images, and audio. It's designed to work well with large language models (LLMs) and provides a simple and scalable way to build applications that can search and retrieve relevant information from vast amounts of data.

+ RDS (Relational Database Service) for [MySQL](https://aws.amazon.com/rds/mysql/) is a managed database service provided by Amazon Web Services (AWS). RDS for MySQL simplifies the setup, operation, and scaling of MySQL databases.
---

## Pre-requisites:

1. Use one of the following kernels: `conda_python3`, `conda_pytorch_p310`, or `conda_tensorflow2_p310`.
2. Install the required packages.
3. Access to the LLM API. 

### Amazon Bedrock Deployment

In this notebook, Llama 3 70B is used. When the notebook is deployed through our CloudFormation template, it is granted the appropriate IAM permissions to send API requests to Amazon Bedrock.

Refer [here](https://aws.amazon.com/blogs/aws/metas-llama-3-models-are-now-available-in-amazon-bedrock/) for details on how Amazon Bedrock provides access to Meta’s Llama 3.

### SageMaker Deployment

#### Changing instance type
---
Models are supported on the following instance types:

- Llama 3 8B Text Generation: `ml.g5.2xlarge`, `ml.g5.4xlarge`, `ml.g5.8xlarge`, `ml.g5.12xlarge`, `ml.g5.24xlarge`, `ml.g5.48xlarge`, and `ml.p4d.24xlarge`
- Llama 3 70B Text Generation: `ml.g5.48xlarge` and `ml.p4d.24xlarge`
- BGE Large En v1.5: `ml.g5.2xlarge`, `ml.c6i.xlarge`, `ml.g5.4xlarge`, `ml.g5.8xlarge`, `ml.p3.2xlarge`, and `ml.g4dn.2xlarge`

By default, the JumpStartModel class selects a default instance type available in your region. If you would like to use a different instance type, you can do so by specifying instance type in the JumpStartModel class.

`my_model = JumpStartModel(model_id=model_id, instance_type="ml.g5.12xlarge")`

---

## Getting Started

### Step 0: Select Hosting Model Service

Here, you can select to run this notebook using SageMaker JumpStart or Amazon Bedrock.

In [1]:
def ask_for_service():
    service = input("Do you want to run the LLM for this notebook using Amazon Bedrock (B) or Amazon SageMaker JumpStart (S)? (default: B) ").strip().upper()
    if service in ['S', 'SAGEMAKER']:
        return 'Amazon SageMaker'
    elif service in ['B', 'BEDROCK', '']:
        return 'Amazon Bedrock'
    else:
        print("Invalid input. Using Amazon Bedrock by default.")
        return 'Amazon Bedrock'

# Call the function and get the selected service
llm_selected_service = ask_for_service()

# Print the selected service
print(f"You have chosen to run the LLM for this notebook using {llm_selected_service}.")

Do you want to run the LLM for this notebook using Amazon Bedrock (B) or Amazon SageMaker JumpStart (S)? (default: B)  B


You have chosen to run the LLM for this notebook using Amazon Bedrock.


In [2]:
def ask_for_service():
    service = input("Do you want to run the Embedding for this notebook using Amazon Bedrock (B) or Amazon SageMaker JumpStart (S)? (default: B) ").strip().upper()
    if service in ['S', 'SAGEMAKER']:
        return 'Amazon SageMaker'
    elif service in ['B', 'BEDROCK', '']:
        return 'Amazon Bedrock'
    else:
        print("Invalid input. Using Amazon Bedrock by default.")
        return 'Amazon Bedrock'

# Call the function and get the selected service
embedding_selected_service = ask_for_service()

# Print the selected service
print(f"You have chosen to run the Embedding for this notebook using {embedding_selected_service}.")

Do you want to run the Embedding for this notebook using Amazon Bedrock (B) or Amazon SageMaker JumpStart (S)? (default: B)  B


You have chosen to run the Embedding for this notebook using Amazon Bedrock.


In [3]:
if llm_selected_service == 'Amazon SageMaker':
    # Import the JumpStartModel class from the SageMaker JumpStart library
    from sagemaker.jumpstart.model import JumpStartModel

    # Specify the JumpStart model IDs for the Meta Llama 3 Instruct models
    llama3_8b_id = "meta-textgeneration-llama-3-8b-instruct"
    llama3_70b_id = "meta-textgeneration-llama-3-70b-instruct"
    DEFULT_LLM_MODEL_ID = llama3_70b_id
    if DEFULT_LLM_MODEL_ID == llama3_70b_id:
        instance_type = "ml.g5.48xlarge"
    else:
        instance_type = "ml.g5.12xlarge"
    model = JumpStartModel(model_id=DEFULT_LLM_MODEL_ID, instance_type=instance_type)
    llm_predictor = model.deploy(accept_eula=True)
    print(f"LLM SageMaker Endpoint Name: {llm_predictor.endpoint_name}")
else:
    llm_predictor = None
    llama3_8b_id = "meta.llama3-8b-instruct-v1:0"
    llama3_70b_id = "meta.llama3-70b-instruct-v1:0"
    DEFULT_LLM_MODEL_ID = llama3_70b_id
    DEFAULT_EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v2:0"

if embedding_selected_service == 'Amazon SageMaker':
    # Import the JumpStartModel class from the SageMaker JumpStart library
    from sagemaker.jumpstart.model import JumpStartModel

    # Deploy BGE Large En embedding model on Amazon SageMaker JumpStart:
    # Specify the model ID for the HuggingFace BGE Large EN Embedding model
    DEFAULT_EMBEDDING_MODEL_ID = "huggingface-sentencesimilarity-bge-large-en"
    text_embedding_model = JumpStartModel(model_id=DEFAULT_EMBEDDING_MODEL_ID)
    embedding_predictor = text_embedding_model.deploy()
    print(f"Embedding SageMaker Endpoint Name: {embedding_predictor.endpoint_name}")
else:
    embedding_predictor = None
    DEFAULT_EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v2:0"

### Step 1: Install Dependencies

Here, we will install all the required dependencies to run this notebook.

In [4]:
!pip install boto3==1.34.127 -qU --force-reinstall --quiet --no-warn-conflicts
!pip install mysql-connector-python==8.4.0 -qU --force-reinstall --quiet --no-warn-conflicts
!pip install langchain==0.2.5 -qU --force-reinstall --quiet --no-warn-conflicts
!pip install chromadb==0.5.0 -qU --force-reinstall --quiet --no-warn-conflicts
!pip install numpy==1.26.4 -qU --force-reinstall --quiet --no-warn-conflicts

**Note:** *When installing libraries with pip, you may encounter errors or warnings during the installation process. These are generally not critical and can be safely ignored. However, after installing the libraries, it is recommended to restart the kernel or computing environment you are working in. Restarting the kernel ensures that the newly installed libraries are loaded properly and available for use in your code or workflow.*

#### Now let's import the required modules to run the notebook

In [5]:
import boto3
import chromadb
from chromadb.api.types import (
    Documents,
    EmbeddingFunction,
    Embeddings,
)
import json
from langchain_core.prompts import PromptTemplate  # root-level `from langchain import PromptTemplate` is deprecated
import mysql.connector as db
import re
from typing import Dict, List, Any
import yaml

In [6]:
# Setup Bedrock Client
bedrock_client = boto3.client(
    service_name='bedrock-runtime'
)

### Step 2: Set up database 

Here, we retrieve the resources already deployed as part of the CloudFormation template that are used to build the application. These include:

+ Secret ARN with RDS for MySQL Database credentials
+ Database Endpoint

In [7]:
stackname = "text2sql"  # If your stack name differs from "text2sql", please provide your stack name here.

In [8]:
cfn = boto3.client('cloudformation')

# Only the stack outputs are needed, so a single describe_stacks call is sufficient
cfn_outputs = cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Outputs']

# Get rds secret arn and database endpoint from cloudformation outputs
for output in cfn_outputs:
    if 'SecretArn' in output['OutputKey']:
        rds_secret_id = output['OutputValue']

    if 'DatabaseEndpoint' in output['OutputKey']:
        db_host = output['OutputValue']
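
The lookup above can also be written as a small pure function over the `Outputs` list, which makes it possible to check the logic without calling AWS and to fail early when an expected output is missing. This is a sketch under assumptions: the helper name `find_output` and the sample keys and values are hypothetical, and the real output keys depend on the CloudFormation template.

```python
from typing import Dict, List, Optional


def find_output(outputs: List[Dict[str, str]], key_fragment: str) -> Optional[str]:
    """Return the value of the first stack output whose key contains key_fragment."""
    for output in outputs:
        if key_fragment in output["OutputKey"]:
            return output["OutputValue"]
    return None


# Hypothetical outputs in the shape found under Stacks[0]['Outputs'].
sample_outputs = [
    {"OutputKey": "RDSSecretArn", "OutputValue": "arn:aws:secretsmanager:us-east-1:111122223333:secret:example"},
    {"OutputKey": "DatabaseEndpoint", "OutputValue": "example.us-east-1.rds.amazonaws.com"},
]

secret_arn = find_output(sample_outputs, "SecretArn")
endpoint = find_output(sample_outputs, "DatabaseEndpoint")
if secret_arn is None or endpoint is None:
    # Failing here gives a clearer error than a NameError in a later cell.
    raise RuntimeError("Expected stack outputs are missing; check the stack name.")
print(secret_arn, endpoint)
```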

In [9]:
secrets_client = boto3.client('secretsmanager')
credentials = json.loads(secrets_client.get_secret_value(SecretId=rds_secret_id)['SecretString'])

# Get password and username from secrets
db_password = credentials['password']
db_user = credentials['username']
db_name = "airline_db"

Establish the database connection

In [10]:
db_conn = db.connect(
    host=db_host,
    user=db_user,
    password=db_password
)

#### Use this section to list the databases that already exist on your test database instance.

In [11]:
db_cursor = db_conn.cursor()

In [12]:
db_cursor.execute("SHOW DATABASES")

for tmp_db_name in db_cursor:
    print(tmp_db_name)

('information_schema',)
('mysql',)
('performance_schema',)
('sys',)


### Step 3: Build Database
The notebook now drops the test tables and the test database if they exist, and then recreates the database and tables.
It then inserts test data for use in our prompting examples.

#### Load table schema settings

In [13]:
def load_settings(file_path):
    """
    Reads a YAML file and returns its contents as a Python object.

    Args:
        file_path (str): The path to the YAML file.

    Returns:
        obj: The contents of the YAML file as a Python object.
    """
    try:
        with open(file_path, 'r') as file:
            data = yaml.safe_load(file)
        return data
    except FileNotFoundError:
        print(f"Error: The file '{file_path}' does not exist.")
    except yaml.YAMLError as exc:
        print(f"Error: Failed to parse the YAML file '{file_path}': {exc}")

In [14]:
# Load table settings
settings_airplanes = load_settings('schemas/airplanes.yml')
table_airplanes = settings_airplanes['table_name']
table_schema_airplanes = settings_airplanes['table_schema']

In [15]:
# Load table settings
settings_flights = load_settings('schemas/flights.yml')
table_flights = settings_flights['table_name']
table_schema_flights = settings_flights['table_schema']

In [16]:
# Load table settings
settings_airplane_flights = load_settings('schemas/airplanes-flights.yml')

#### Clean up database

In [17]:
# Delete flights' table
db_cursor.execute(f"DROP TABLE IF EXISTS {db_name}.{table_flights}")

In [18]:
# Delete airplanes' table
db_cursor.execute(f"DROP TABLE IF EXISTS {db_name}.{table_airplanes}")

In [19]:
# Delete database
db_cursor.execute(f"DROP DATABASE IF EXISTS {db_name}")

#### Create database and tables 

In [20]:
# Create database `airline_db`
db_cursor.execute(f"CREATE DATABASE {db_name}")

In [21]:
# Create table to hold data on fictitious airplanes information called `airplanes`
db_cursor.execute(table_schema_airplanes)

In [22]:
# Create table to hold data on fictitious flights information called `flights`
db_cursor.execute(table_schema_flights)

#### Read sample data

In [23]:
# Read sample data for the airplanes' table
with open('sample_data/airplanes.json', 'r') as f:
    data_airplanes = json.load(f)

In [24]:
# Read sample data for the flights' table
with open('sample_data/flights.json', 'r') as f:
    data_flights = json.load(f)

#### Ingest sample data into database

In [25]:
# Insert airplanes' data into database using a parameterized query,
# so values containing quotes cannot break or alter the SQL statement
sql = f"""
    INSERT INTO {db_name}.{table_airplanes}
    (Airplane_id, Producer, Type)
    VALUES (%s, %s, %s)
    """
rows = [
    (data['Airplane_id'], data['Producer'], data['Type'])
    for data in data_airplanes
]
db_cursor.executemany(sql, rows)
db_conn.commit()

In [26]:
# Insert flights' data into database using a parameterized query,
# so values containing quotes cannot break or alter the SQL statement
sql = f"""
    INSERT INTO {db_name}.{table_flights}
    (Flight_number, Arrival_time, Arrival_date, Departure_time, Departure_date, Destination, Airplane_id)
    VALUES (%s, %s, %s, %s, %s, %s, %s)
    """
rows = [
    (
        data['Flight_number'], data['Arrival_time'], data['Arrival_date'],
        data['Departure_time'], data['Departure_date'], data['Destination'],
        data['Airplane_id'],
    )
    for data in data_flights
]
db_cursor.executemany(sql, rows)
db_conn.commit()

Verify our database connection works and we can retrieve records from our table.

In [27]:
db_cursor.execute(f"SELECT * FROM {db_name}.{table_airplanes}")
sql_data = db_cursor.fetchall()

for record in sql_data:
    print(record)

(1, 'Boeing', '737')
(2, 'Airbus', 'A320')
(3, 'Embraer', 'E195')
(4, 'Bombardier', 'CRJ900')
(5, 'Boeing', '777')
(6, 'Airbus', 'A330')
(7, 'Embraer', 'E175')
(8, 'Bombardier', 'Q400')
(9, 'Boeing', '787')
(10, 'Airbus', 'A350')
(11, 'Embraer', 'E190')
(12, 'Bombardier', 'CRJ700')
(13, 'Boeing', '757')
(14, 'Airbus', 'A380')
(15, 'Embraer', 'E170')
(16, 'Bombardier', 'CRJ200')
(17, 'Boeing', '747')
(18, 'Airbus', 'A321')
(19, 'Embraer', 'E145')
(20, 'Bombardier', 'CRJ1000')


In [28]:
db_cursor.execute(f"SELECT * FROM {db_name}.{table_flights}")
sql_data = db_cursor.fetchall()

for record in sql_data:
    print(record)

('AA123', '2023-06-15T10:30:00', '2023-06-15', '2023-06-15T08:00:00', '2023-06-15', 'Los Angeles', 1)
('AA234', '2023-07-02T21:15:00', '2023-07-02', '2023-07-02T18:30:00', '2023-07-02', 'Tampa', 20)
('AA890', '2023-06-24T18:40:00', '2023-06-24', '2023-06-24T16:10:00', '2023-06-24', 'Atlanta', 5)
('AS345', '2023-06-19T21:00:00', '2023-06-19', '2023-06-19T18:30:00', '2023-06-19', 'Seattle', 7)
('AS789', '2023-06-27T15:50:00', '2023-06-27', '2023-06-27T13:20:00', '2023-06-27', 'Phoenix', 7)
('DL123', '2023-06-25T22:00:00', '2023-06-25', '2023-06-25T19:30:00', '2023-06-25', 'Las Vegas', 10)
('DL345', '2023-06-29T07:30:00', '2023-06-29', '2023-06-29T05:00:00', '2023-06-29', 'Philadelphia', 6)
('DL567', '2023-07-03T09:40:00', '2023-07-03', '2023-07-03T07:10:00', '2023-07-03', 'San Diego', 19)
('DL789', '2023-06-17T18:20:00', '2023-06-17', '2023-06-17T16:00:00', '2023-06-17', 'Miami', 10)
('DL901', '2023-06-21T13:20:00', '2023-06-21', '2023-06-21T10:50:00', '2023-06-21', 'Boston', 6)
('JB012'

### Step 4: Create helper functions

To facilitate the usability and readability of the SQL query analysis produced by Llama 3, we have developed a suite of helper functions tailored to various use cases.

The `format_instructions` function builds the prompt for Llama 3 models from messages with the `system` and `user` roles, and appends the `assistant` header so that the model generates the reply. For more details about Llama 3 prompt formats, click [here](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/).

In [29]:
def format_instructions(instructions: List[Dict[str, str]]) -> str:
    """Format system/user instructions as a single Llama 3 prompt string that ends with an open assistant header."""
    prompt: List[str] = []
    for instruction in instructions:
        if instruction["role"] == "system":
            prompt.extend(["<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n", (instruction["content"]).strip(), "<|eot_id|>"])
        elif instruction["role"] == "user":
            prompt.extend(["<|start_header_id|>user<|end_header_id|>\n", (instruction["content"]).strip(), "<|eot_id|>"])
        else:
            raise ValueError(f"Invalid role: {instruction['role']}. Role must be either 'user' or 'system'.")
    prompt.extend(["<|start_header_id|>assistant<|end_header_id|>\n"])
    return "".join(prompt)
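
To see what the resulting prompt string looks like, the sketch below restates the same formatting in a compact, self-contained form and prints the result for a two-message conversation. The function name `render_llama3_prompt` and the sample messages are introduced here for illustration only.

```python
from typing import Dict, List


def render_llama3_prompt(messages: List[Dict[str, str]]) -> str:
    """Illustrative restatement of format_instructions for system and user messages."""
    parts = []
    for message in messages:
        role = message["role"]
        if role not in ("system", "user"):
            raise ValueError(f"Invalid role: {role}")
        # The begin-of-text token is emitted together with the system message.
        prefix = "<|begin_of_text|>" if role == "system" else ""
        parts.append(f"{prefix}<|start_header_id|>{role}<|end_header_id|>\n")
        parts.append(message["content"].strip())
        parts.append("<|eot_id|>")
    # Finish with an open assistant header so the model writes the reply.
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n")
    return "".join(parts)


prompt = render_llama3_prompt([
    {"role": "system", "content": "You answer with SQL only."},
    {"role": "user", "content": "How many airplanes are there?"},
])
print(prompt)
```

Each message is wrapped in a role header and terminated with `<|eot_id|>`, and the trailing assistant header is left open so that the completion is the assistant's turn.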

The `execute_query` function will execute SQL queries, typically for retrieving data from a database, and format the results as a string for further processing or display. 

In [30]:
def execute_query(query: str) -> str:
    """Execute an SQL query on the database connection and return the results as a string.

    Args:
        query (str): SQL query to execute

    Returns:
        str: A formatted string containing the SQL results.
    """
    # Get a cursor from the database connection
    mycursor = db_conn.cursor()

    # Execute the SQL query
    mycursor.execute(query)

    # Fetch all result rows
    result_rows = mycursor.fetchall()

    # Convert result to string with newline between rows
    output_text = '\n'.join([str(x) for x in result_rows])
    return output_text
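
The row-to-text formatting can be tried without a MySQL instance. In the sketch below an in-memory SQLite database stands in for the MySQL connection, which is an assumption made purely so the example is self-contained; the helper name `rows_to_text` is also hypothetical. The two sample rows mirror the first rows of the `airplanes` table.

```python
import sqlite3


def rows_to_text(connection: sqlite3.Connection, query: str) -> str:
    """Run a query and return one result row per line, in the style of execute_query."""
    cursor = connection.cursor()
    try:
        cursor.execute(query)
        return "\n".join(str(row) for row in cursor.fetchall())
    finally:
        # Close the cursor even if the query fails.
        cursor.close()


# In-memory SQLite stands in for the MySQL connection used in this notebook.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE airplanes (Airplane_id INTEGER, Producer TEXT, Type TEXT)")
conn.executemany(
    "INSERT INTO airplanes VALUES (?, ?, ?)",
    [(1, "Boeing", "737"), (2, "Airbus", "A320")],
)
text = rows_to_text(conn, "SELECT * FROM airplanes ORDER BY Airplane_id")
print(text)
```

Each row is rendered as a Python tuple on its own line, which is the plain-text form later passed to the model as `query_results`.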

The `sagemaker_chat_completion` function uses the SageMaker Endpoint to invoke the LLMs. The response from the LLM is extracted and returned as text.

In [31]:
def sagemaker_chat_completion(
    prompt: str,
    max_gen_len: int = 512,
    temperature: float = 0.5,
    top_p: float = 0.999
) -> str:
    """
    Generates a chat completion from a prompt using the llama3 model via Amazon SageMaker JumpStart.

    Args:
        prompt (str): The prompt text to generate completions for.
        max_gen_len (int, optional): The maximum length of the completion.
        temperature (float, optional): Sampling temperature for the model.
        top_p (float, optional): Top p sampling ratio for the model.

    Returns:
        str: The generated text completion.
    """
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_gen_len,
            "temperature": temperature,
            "top_p": top_p,
            "stop": ["<|eot_id|>"]
        }
    }

    # Call the model API to generate the completion
    response = llm_predictor.predict(body)
    completion = response.get('generated_text', '')

    return completion.strip()

The `bedrock_chat_completion` function uses the Bedrock client to invoke the LLMs. The response from the LLM is extracted and returned as text.

In [32]:
def bedrock_chat_completion(
    model_id: str,
    prompt: str,
    max_gen_len: int = 512,
    temperature: float = 0.5,
    top_p: float = 0.999
) -> str:
    """
    Generates a chat completion from a prompt using the llama3 model via Amazon Bedrock.

    Args:
        model_id (str): The ID of the llama3 model to use for completion.
        prompt (str): The prompt text to generate completions for.
        max_gen_len (int, optional): The maximum length of the completion.
        temperature (float, optional): Sampling temperature for the model.
        top_p (float, optional): Top p sampling ratio for the model.

    Returns:
        str: The generated text completion.
    """
    body = {
        "prompt": prompt,
        "max_gen_len": max_gen_len,
        "temperature": temperature,
        "top_p": top_p,
    }

    accept = "application/json"
    contentType = "application/json"

    # Convert the body dictionary to JSON string and encode it as bytes
    body_json = json.dumps(body)
    body_bytes = body_json.encode('utf-8')

    # Call the model API to generate the completion
    response = bedrock_client.invoke_model(
        body=body_bytes, modelId=model_id, accept=accept, contentType=contentType
    )
    response_body = response["body"].read()
    response_body = json.loads(response_body)
    completion = response_body.get("generation", "")

    return completion.strip()
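
The request and response handling can be exercised offline with a stub. In the sketch below, `FakeBedrockClient` is a hypothetical stand-in that only mimics the shape of the `invoke_model` call used above (a JSON request body in, a readable `body` stream containing a `generation` field out); it does not contact Amazon Bedrock, and the canned completion is invented for the example.

```python
import io
import json


class FakeBedrockClient:
    """Hypothetical stub mimicking the request/response shape of invoke_model."""

    def invoke_model(self, body, modelId, accept, contentType):
        request = json.loads(body)
        # A real call would send request["prompt"] to the model; here the reply is canned.
        assert "prompt" in request
        payload = {"generation": " <sql>SELECT COUNT(*) FROM airplanes;</sql>\n"}
        return {"body": io.BytesIO(json.dumps(payload).encode("utf-8"))}


def parse_generation(response) -> str:
    """Read the response stream, decode the JSON, and return the stripped 'generation' text."""
    response_body = json.loads(response["body"].read())
    return response_body.get("generation", "").strip()


client = FakeBedrockClient()
request_body = json.dumps(
    {"prompt": "How many airplanes are there?", "max_gen_len": 512, "temperature": 0.5, "top_p": 0.999}
)
response = client.invoke_model(
    body=request_body.encode("utf-8"),
    modelId="meta.llama3-70b-instruct-v1:0",
    accept="application/json",
    contentType="application/json",
)
completion = parse_generation(response)
print(completion)
```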

The function `get_llm_sql_analysis` generates and executes a SQL query for a given question, and returns a comprehensive analysis based on the SQL query results.

In [33]:
def get_llm_sql_analysis(question: str, sql_sys_prompt: str, qna_sys_prompt: str) -> str:
    """
    Generates an SQL query based on the given question, executes it, and returns an analysis of the results using Llama 3.

    Args:
        question (str): The input question for which an SQL query needs to be generated.
        sql_sys_prompt (str): The prompt to be used for generating the SQL query using Llama 3.
        qna_sys_prompt (str): The prompt to be used for analyzing the SQL query results using Llama 3.

    Returns:
        str: The analysis of the SQL query results provided by the language model.
    """
    if llm_selected_service == 'Amazon SageMaker':
        # Generates SQL query
        completion = sagemaker_chat_completion(
            prompt=sql_sys_prompt
        )
    else:
        # Generates SQL query
        completion = bedrock_chat_completion(
            model_id=DEFULT_LLM_MODEL_ID,
            prompt=sql_sys_prompt
        )

    try:
        # Extract the SQL query
        pattern = r"<sql>(.*)</sql>"
        llm_sql_query = re.search(pattern, completion, re.DOTALL).group(1)
        print(f"LLM SQL Query: \n{llm_sql_query}")

        # Execute SQL query
        sql_results = execute_query(llm_sql_query)

        if llm_selected_service == 'Amazon SageMaker':
            # Generates SQL analysis
            llm_sql_analysis = sagemaker_chat_completion(
                prompt=qna_sys_prompt.format(query_results=sql_results, question=question)
            )
        else:
            # Generates SQL analysis
            llm_sql_analysis = bedrock_chat_completion(
                model_id=DEFULT_LLM_MODEL_ID,
                prompt=qna_sys_prompt.format(query_results=sql_results, question=question)
            )

        print(f"LLM SQL Analysis: \n{llm_sql_analysis}")
        return llm_sql_analysis
    except Exception as e:
        # Report the failure and return its message so the return type stays a string
        print(e)
        return str(e)
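
The tag-extraction step is worth isolating, because a completion without `<sql>` tags makes `re.search` return `None`. The sketch below is one possible variant, not the notebook's own code: the helper name `extract_sql` is hypothetical, it uses a non-greedy pattern so that only the first tagged query is captured, and it returns `None` instead of raising when no tags are present.

```python
import re
from typing import Optional

# Non-greedy match so that only the first <sql>...</sql> pair is captured.
SQL_TAG_PATTERN = re.compile(r"<sql>(.*?)</sql>", re.DOTALL)


def extract_sql(completion: str) -> Optional[str]:
    """Return the text inside the first <sql>...</sql> pair, or None if no pair is found."""
    match = SQL_TAG_PATTERN.search(completion)
    return match.group(1).strip() if match else None


tagged = "Sure!\n<sql>\nSELECT COUNT(*) FROM airplanes;\n</sql>"
print(extract_sql(tagged))
print(extract_sql("I cannot answer that."))
```

Checking for `None` before executing the query lets the caller distinguish a completion that contained no SQL from a query that failed in the database.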

The class `AmazonBedrockEmbeddingFunction` initializes an embedding function with `Amazon Titan Embedding Model V2` that integrates with ChromaDB. This class can be further extended to add support for other embedding models available on Amazon Bedrock.

In [34]:
class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        session: "boto3.Session",
        model_name: str = "amazon.titan-embed-text-v2:0",
        **kwargs: Any,
    ):
        """Initialize AmazonBedrockEmbeddingFunction.

        Args:
            session (boto3.Session): The boto3 session to use.
            model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v2:0"
            **kwargs: Additional arguments to pass to the boto3 client.

        Example:
            >>> import boto3
            >>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
            >>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = bedrock(texts)
        """

        self._model_name = model_name

        self._client = session.client(
            service_name="bedrock-runtime",
            **kwargs,
        )

    def __call__(self, input: Documents) -> Embeddings:
        accept = "*/*"
        content_type = "application/json"
        embeddings = []
        for text in input:
            input_body = {"inputText": text, "dimensions": 512, "normalize": True}
            body = json.dumps(input_body)
            response = self._client.invoke_model(
                body=body,
                modelId=self._model_name,
                accept=accept,
                contentType=content_type,
            )
            embedding = json.load(response.get("body")).get("embedding")
            embeddings.append(embedding)
        return embeddings

In [35]:
class AmazonSageMakerEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        **kwargs: Any,
    ):
        """Initialize AmazonSageMakerEmbeddingFunction.

        Args:
            **kwargs: Additional arguments to pass to the sagemaker embedding function.

        Example:
            >>> sagemaker = AmazonSageMakerEmbeddingFunction()
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = sagemaker(texts)
        """


    def __call__(self, input: Documents) -> Embeddings:
        accept = "application/json"
        content_type = "application/json"

        embeddings = []
        for text in input:
            input_body = {"text_inputs": text, "mode": "embedding"}
            body = json.dumps(input_body).encode('utf-8')
            response = embedding_predictor.predict(
                body,
                {
                    "ContentType": content_type,
                    "Accept": accept,
                }
            )
            embedding = response.get("embedding")
            embeddings.append(embedding)
        return embeddings

## Few-Shot Text-to-SQL
With our database and tables filled with data, we're now ready to walk through the Few-Shot Text-to-SQL approach, using the helper functions defined above.

## Analyzing a Single Table with Few-Shot Learning

### Step 1: Create a Few-Shot Prompt
Here, we design a prompt template that takes the table schema and the user's question and is formatted correctly for use with Llama 3 models.

First, we create a `system prompt` containing two parts:

1. `table_schema`. This is a description of the structure of the database table, including the name of the table, the names of the columns within each table, and the data types of each column. This information helps Llama 3 to understand the organization and contents of the table.

2. `question`. This is the specific request or information that the user wants to obtain from the table.

By including both the table schema and the user's question in the prompt, we give the Llama 3 model a complete understanding of the table structure and the user's desired output.

In [36]:
instructions = [
    {
        "role": "system",
        "content": 
        """You are a mysql query expert whose output is a valid sql query.

Only use the following tables:

It has the following schema:
<table_schema>
{table_schema}
</table_schema>

Please construct a valid SQL statement to answer the following question, return only the mysql query in between <sql></sql>.
"""
    },
    {
        "role": "user",
        "content": "{question}"
    }
]
tmp_sql_sys_prompt = format_instructions(instructions)
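
The `format_instructions` helper used above is defined earlier in the notebook. As a rough illustration of what formatting a prompt for Llama 3 involves, the sketch below renders a list of role/content messages with the Llama 3 instruct special tokens. The function name `format_llama3_chat` is hypothetical and this is not necessarily how the notebook's helper is implemented.

```python
from typing import Dict, List


def format_llama3_chat(messages: List[Dict[str, str]]) -> str:
    """Render chat messages with the Llama 3 instruct special tokens.

    Hypothetical stand-in for illustration only; the notebook's own
    `format_instructions` helper may differ in detail.
    """
    prompt = "<|begin_of_text|>"
    for message in messages:
        # Each turn is a role header, a blank line, the content, and an end-of-turn token.
        prompt += (
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
            f"{message['content'].strip()}<|eot_id|>"
        )
    # Leave an open assistant header so the model writes the next turn.
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt


example = format_llama3_chat([
    {"role": "system", "content": "You are a mysql query expert."},
    {"role": "user", "content": "{question}"},
])
print(example)
```

The `{question}` placeholder is kept verbatim in the rendered string, so it can still be filled in later by a template engine.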

Next, we create a second prompt containing two parts:

1. `query_results`. These are the results returned by executing the SQL query that the model generated from `tmp_sql_sys_prompt`. This is the raw data that the Llama 3 model will use to generate its analysis.

2. `question`. This specifies the type of analysis or insight that the user wants the Llama 3 model to provide based on the SQL query results.

By combining the SQL query results and the user's question into a single prompt, we give the Llama 3 model all the information it needs to understand the context and provide an analysis tailored to the user's request.

In [37]:
instructions = [
    {
        "role": "user",
        "content": """Given the following SQL query results:
{query_results}

And the original question:
{question}

Please provide an analysis and interpretation of the results to answer the original question.
"""
    }
]
QNA_SYS_PROMPT = format_instructions(instructions)

Building on our last prompt, we'll now add a single-shot example to our context to give the model a better hint of the response we expect.

### Step 2: Execute Few Shot Prompts
The following cells demonstrate different questions asked in natural language and the SQL generated by the LLM. The output is contained between the `<sql>` and `</sql>` tags.
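
Because the prompt asks the model to wrap its answer in `<sql></sql>`, the query has to be pulled out of the raw completion before it can be executed. A minimal sketch of that step, assuming the completion is a plain string (`extract_sql` is a hypothetical helper name, not a function from the notebook):

```python
import re
from typing import Optional

# Non-greedy match so only the first <sql>...</sql> pair is captured.
SQL_TAG_PATTERN = re.compile(r"<sql>(.*?)</sql>", re.DOTALL | re.IGNORECASE)


def extract_sql(llm_output: str) -> Optional[str]:
    """Return the query between <sql> tags, or None when the tags are missing."""
    match = SQL_TAG_PATTERN.search(llm_output)
    if match is None:
        return None
    return match.group(1).strip()


sample_output = "<sql>\nSELECT COUNT(*) FROM airline_db.airplanes;\n</sql>"
print(extract_sql(sample_output))
```

Returning `None` instead of raising lets the caller decide whether to retry the prompt or report that the model did not follow the output format.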

In [38]:
# Business question
question = "What is the total count of airplanes?"

# Generate a prompt to get the LLM to provide an SQL query
SQL_SYS_PROMPT = PromptTemplate.from_template(tmp_sql_sys_prompt).format(
    question=question,
    table_schema=table_schema_airplanes
)

results = get_llm_sql_analysis(
    question=question,
    sql_sys_prompt=SQL_SYS_PROMPT,
    qna_sys_prompt=QNA_SYS_PROMPT
)

LLM SQL Query: 
SELECT COUNT(*) FROM airline_db.airplanes;
LLM SQL Analysis: 
A simple yet straightforward question!

Let's analyze the SQL query results:

The result is a single row with a single column containing the value `20`.

This suggests that the SQL query was designed to count the total number of airplanes in the database. The query likely used a `COUNT` aggregation function to tally up the number of rows in a table related to airplanes.

Given this result, we can confidently answer the original question:

**The total count of airplanes is 20.**


## Analyzing Multiple Tables with Few-Shot Learning

### Step 1: Create a Few-Shot Prompt
Now, let's try the same approach using two tables.

First, we create a `system prompt` containing the same placeholders as before, this time with two table schemas.

In [39]:
instructions = [
    {
        "role": "system",
        "content": 
        """You are a mysql query expert whose output is a valid sql query.

Only use the following tables:

It has the following schemas:
<table_schemas>
<table_schema>
{table_schema1}
</table_schema>

<table_schema>
{table_schema2}
</table_schema>
</table_schemas>

Please construct a valid SQL statement to answer the following question, return only the mysql query in between <sql></sql>.
"""
    },
    {
        "role": "user",
        "content": "{question}"
    }
]
tmp_sql_sys_prompt = format_instructions(instructions)

### Step 2: Execute Few Shot Prompts

In [40]:
# Business question
question = "What is the total count of flights per producer?"

# Generate a prompt to get the LLM to provide an SQL query
SQL_SYS_PROMPT = PromptTemplate.from_template(tmp_sql_sys_prompt).format(
    question=question,
    table_schema1=table_schema_airplanes,
    table_schema2=table_schema_flights
)

results = get_llm_sql_analysis(
    question=question,
    sql_sys_prompt=SQL_SYS_PROMPT,
    qna_sys_prompt=QNA_SYS_PROMPT
)

LLM SQL Query: 

SELECT a.Producer, COUNT(f.Flight_number) AS Total_Flights
FROM airline_db.airplanes a
JOIN airline_db.flights f ON a.Airplane_id = f.Airplane_id
GROUP BY a.Producer;

LLM SQL Analysis: 
Based on the provided SQL query results, we can analyze and interpret the data to answer the original question:

**Original Question:** What is the total count of flights per producer?

**Results:**

1. Boeing: 4 flights
2. Airbus: 8 flights
3. Embraer: 5 flights
4. Bombardier: 3 flights

**Analysis and Interpretation:**

The results show the count of flights for each aircraft producer. To answer the original question, we can simply add up the counts for each producer to get the total count of flights:

**Total Count of Flights:** 4 (Boeing) + 8 (Airbus) + 5 (Embraer) + 3 (Bombardier) = **20 flights**

Therefore, the total count of flights per producer is 20 flights.

Additionally, we can observe the following insights from the results:

* Airbus has the highest number of flights (8), 

## Limitations of Few-Shot Learning

Few-shot learning for text-to-SQL tasks, where a language model is given a limited number of examples in its prompt to translate natural language questions into SQL queries, faces significant limitations. One of the key challenges is selecting the appropriate table schema that aligns with the user's question.

In a real-world scenario, databases often consist of numerous tables with intricate relationships, making it difficult for the model to identify the relevant tables and columns required to answer a given query accurately. 

To address this issue, we propose incorporating ChromaDB to facilitate the retrieval of table schemas that are tailored to the user's question.

Here's how ChromaDB can help overcome the table schema selection challenge:

**Table Schema Retrieval**: Each table schema in the database can be converted into a dense vector embedding that captures its structural information and relationships. The user's question is embedded in the same way and compared against the stored schema vectors. The top-ranked table schemas are then retrieved and provided as input to the text-to-SQL model, increasing the likelihood of generating accurate SQL queries.
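
The ranking idea can be sketched without any vector database. The example below scores hand-written toy vectors with cosine similarity and returns the best match. The vectors and the names `cosine_similarity` and `rank_schemas` are illustrative assumptions; real embeddings come from an embedding model and have hundreds of dimensions.

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_schemas(
    question_vec: List[float],
    schema_vecs: Dict[str, List[float]],
    top_k: int = 1,
) -> List[Tuple[str, float]]:
    """Return the top_k schema names ordered by similarity to the question."""
    scored = [
        (name, cosine_similarity(question_vec, vec))
        for name, vec in schema_vecs.items()
    ]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]


# Toy vectors, invented for illustration only.
toy_schema_vectors = {
    "airplanes": [0.9, 0.1, 0.0],
    "flights": [0.1, 0.9, 0.2],
}
toy_question_vector = [0.8, 0.2, 0.1]

ranked = rank_schemas(toy_question_vector, toy_schema_vectors, top_k=1)
print(ranked[0][0])
```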

We will review this approach further in the next section.

## Few-shot text-to-SQL powered by ChromaDB

Here, we will use ChromaDB and the few-shot technique to efficiently retrieve table schemas for better performance and generalization across different databases and query types.

## Schema Retrieval

In this approach, we will store only the table schemas in ChromaDB.

### Step 1: Data Preprocessing

The first step is to preprocess the data and create a document that will be ingested into ChromaDB. The final doc clearly separates the table schemas by using XML tags such as `<table_schema></table_schema>`.

In [41]:
# The doc includes a structure format for clearly identifying the table schemas
doc1 = "<table_schemas>\n"
doc1 += f"<table_schema>\n {settings_airplanes['table_schema']} \n</table_schema>\n".strip()
doc1 += "\n</table_schemas>"
print(doc1)

<table_schemas>
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
</table_schemas>


In [42]:
# The doc includes a structure format for clearly identifying the table schemas
doc2 = "<table_schemas>\n"
doc2 += f"<table_schema>\n {settings_flights['table_schema']} \n</table_schema>\n".strip()
doc2 += "\n</table_schemas>"
print(doc2)

<table_schemas>
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>
</table_schemas>


In [43]:
# The doc includes a structure format for clearly identifying the table schemas
doc3 = "<table_schemas>\n"
for table_schema in settings_airplane_flights['table_schemas']:
    doc3 += f"<table_schema>\n {table_schema} \n</table_schema>\n"
doc3 += "\n</table_schemas>".strip()
print(doc3)

<table_schemas>
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>
</table_schemas>
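
The three cells above repeat the same wrapping logic. One possible way to factor it out is a small helper that accepts any number of schemas. `build_schema_doc` is a hypothetical name, and its whitespace differs slightly from the cells above, which pad each schema with a leading space.

```python
from typing import Iterable


def build_schema_doc(table_schemas: Iterable[str]) -> str:
    """Wrap each schema in <table_schema> tags inside one <table_schemas> block."""
    parts = ["<table_schemas>"]
    for schema in table_schemas:
        parts.append(f"<table_schema>\n{schema.strip()}\n</table_schema>")
    parts.append("</table_schemas>")
    return "\n".join(parts)


# Placeholder schema, used only to show the resulting layout.
toy_schema = "CREATE TABLE demo_db.demo (id INT, PRIMARY KEY (id))"
doc = build_schema_doc([toy_schema])
print(doc)
```

With such a helper, the single-table documents would be built from a one-element list and the combined document from a two-element list.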


### Step 2: Ingest docs into ChromaDB

After the data is preprocessed, the next step is to ingest all `docs` into ChromaDB.

In [44]:
# Setup Chroma in-memory, for easy prototyping.
chroma_client = chromadb.Client()

In [45]:
# Create collection using ChromaDB's internal embedding function
collection1 = chroma_client.create_collection(name="table-schemas-default-embedding", metadata={"hnsw:space": "cosine"})

In [46]:
# Add docs to the collection.
collection1.add(
    documents=[
        doc1,
        doc2,
        doc3
    ],
    metadatas=[
        {"source": "mysql", "database": db_name, "table_name": table_airplanes},
        {"source": "mysql", "database": db_name, "table_name": table_flights},
        {"source": "mysql", "database": db_name, "table_name": f"{table_airplanes}-{table_flights}" }
    ],
    ids=[table_airplanes, table_flights, f"{table_airplanes}-{table_flights}"], # unique for each doc
)

/home/ec2-user/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:01<00:00, 58.6MiB/s]


### Step 3: Create a Few-Shot Prompt

Now, we'll use a few-shot approach using the retrieved tables from ChromaDB.

First, we create a `system prompt` containing a single placeholder that accepts any number of table schemas retrieved from ChromaDB.

In [47]:
instructions = [
    {
        "role": "system",
        "content": 
        """You are a mysql query expert whose output is a valid sql query.

Only use the following tables:

It has the following schemas:
<table_schemas>
{table_schemas}
</table_schemas>

Always combine the database name and table name to build your queries. You must identify these two values before providing a valid SQL query.

Please construct a valid SQL statement to answer the following question, return only the mysql query in between <sql></sql>.
"""
    },
    {
        "role": "user",
        "content": "{question}"
    }
]
tmp_sql_sys_prompt = format_instructions(instructions)
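
The prompt above asks the model to always qualify table names with the database name. A lightweight way to check a generated query against that instruction is sketched below. It is a simple regex heuristic, not a SQL parser, and `unqualified_tables` is a hypothetical helper name.

```python
import re
from typing import List

# Capture the identifier that follows FROM or JOIN (heuristic, not a full SQL parser).
TABLE_REF_PATTERN = re.compile(r"\b(?:FROM|JOIN)\s+([\w.]+)", re.IGNORECASE)


def unqualified_tables(sql: str) -> List[str]:
    """Return table references that lack a `database.` prefix."""
    return [name for name in TABLE_REF_PATTERN.findall(sql) if "." not in name]


print(unqualified_tables("SELECT COUNT(*) FROM airline_db.airplanes;"))
print(unqualified_tables(
    "SELECT * FROM airplanes a "
    "JOIN airline_db.flights f ON a.Airplane_id = f.Airplane_id"
))
```

An empty list means every table reference the heuristic found carries a database prefix. Subqueries, backtick-quoted names, and comma-separated table lists are not handled.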

### Step 4: Execute Few Shot Prompts

In this example, we expect the table `airplanes` to be used for the LLM SQL analysis.

In [48]:
# Business question
question = "What is the total count of airplanes?"

# Query/search 1 most similar results.
docs = collection1.query(
    query_texts=[question],
    n_results=1
)

pattern = r"<table_schemas>(.*)</table_schemas>"
table_schemas = re.search(pattern, docs["documents"][0][0], re.DOTALL).group(1)
print(f"ChromaDB - Schema Retrieval: \n{table_schemas.strip()}")

# Generate a prompt to get the LLM to provide an SQL query
SQL_SYS_PROMPT = PromptTemplate.from_template(tmp_sql_sys_prompt).format(
    question=question,
    table_schemas=table_schemas,
)

results = get_llm_sql_analysis(
    question=question,
    sql_sys_prompt=SQL_SYS_PROMPT,
    qna_sys_prompt=QNA_SYS_PROMPT
)

ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>
LLM SQL Query: 
SELECT COUNT(DISTINCT Airplane_id) FROM airline_db.flights;
LLM SQL Analysis: 
Based on the SQL query results, we can see that the output is a single row with a single column containing the value 12.

This suggests that the SQL query was designed to count the total number of airplanes in the database, and the result is 12.

Therefore, the answer to the original question "What is the total count of airplanes?" is:

There a

### Step 5: Insights and Observations

We can observe that ChromaDB, using its default embedding model, did not retrieve the schema for the `airplanes` table. It returned the `flights` table instead, most likely because `flights` contains an `Airplane_id` field that references the `airplanes` table as a foreign key. As a result, the generated query counted the distinct `Airplane_id` values in `flights` and returned 12 instead of the expected 20.
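
One way to investigate a retrieval miss like this is to request more than one result and compare the distances of the candidates. The sketch below works on a dictionary shaped like a ChromaDB query result, with `ids` and `distances` as lists of lists (one inner list per query text). The distance values are invented placeholders, not measurements from this notebook, and `rank_candidates` is a hypothetical helper name.

```python
from typing import Any, Dict, List, Tuple


def rank_candidates(query_result: Dict[str, Any]) -> List[Tuple[str, float]]:
    """Pair each retrieved id with its distance for the first query, closest first."""
    ids = query_result["ids"][0]
    distances = query_result["distances"][0]
    return sorted(zip(ids, distances), key=lambda pair: pair[1])


# Hand-written stand-in for a query result; the distances are made up.
fake_result = {
    "ids": [["flights", "airplanes", "airplanes-flights"]],
    "distances": [[0.41, 0.44, 0.47]],
}

for doc_id, distance in rank_candidates(fake_result):
    print(f"{doc_id}: {distance:.2f}")
```

In the notebook, the input could come from a call such as `collection1.query(query_texts=[question], n_results=3)`. If the result does not include distances, they can be requested explicitly through the `include` argument. Closely spaced distances would suggest that the embedding model barely separates the candidate schemas.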

To mitigate this issue, we will use a more robust embedding model like the `Amazon Titan Embedding Model`.

We will see this in action in the following section.

## Enhanced Schema Retrieval with an Embedding Model

In this approach, we will use an embedding model from AWS.

### Step 1: Ingest docs into ChromaDB

After the data is preprocessed, the next step is to ingest all `docs` into ChromaDB using your selected embedding model (`Amazon SageMaker BGE Large English` or `Amazon Titan Embedding V2`).

In [49]:
# Create embedding function with AWS
if embedding_selected_service == "Amazon SageMaker":
    aws_ef = AmazonSageMakerEmbeddingFunction()
else:
    session = boto3.Session()
    aws_ef = AmazonBedrockEmbeddingFunction(
        session=session,
        model_name=DEFAULT_EMBEDDING_MODEL_ID
    )

In [50]:
# Create collection using the selected AWS embedding model
collection2 = chroma_client.create_collection(name="table-schemas-aws-embedding-model", embedding_function=aws_ef, metadata={"hnsw:space": "cosine"})

In [51]:
# Add docs to the collection.
collection2.add(
    documents=[
        doc1,
        doc2,
        doc3
    ],
    metadatas=[
        {"source": "mysql", "database": db_name, "table_name": table_airplanes},
        {"source": "mysql", "database": db_name, "table_name": table_flights},
        {"source": "mysql", "database": db_name, "table_name": f"{table_airplanes}-{table_flights}" }
    ],
    ids=[table_airplanes, table_flights, f"{table_airplanes}-{table_flights}"], # unique for each doc
)

### Step 2: Execute Few Shot Prompts

In this example, we expect the table `airplanes` to be included in the LLM SQL analysis.

In [52]:
# Business question
question = "What is the total count of airplanes?"

# Query/search 1 most similar results.
docs = collection2.query(
    query_texts=[question],
    n_results=1
)

pattern = r"<table_schemas>(.*)</table_schemas>"
table_schemas = re.search(pattern, docs["documents"][0][0], re.DOTALL).group(1)
print(f"ChromaDB - Schema Retrieval: \n{table_schemas.strip()}")

# Generate a prompt to get the LLM to provide an SQL query
SQL_SYS_PROMPT = PromptTemplate.from_template(tmp_sql_sys_prompt).format(
    question=question,
    table_schemas=table_schemas,
)

results = get_llm_sql_analysis(
    question=question,
    sql_sys_prompt=SQL_SYS_PROMPT,
    qna_sys_prompt=QNA_SYS_PROMPT
)

ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
LLM SQL Query: 
SELECT COUNT(*) FROM airline_db.airplanes;
LLM SQL Analysis: 
Based on the SQL query results, we can see that the output is a single value: `(20)`.

This suggests that the SQL query was designed to count the total number of airplanes in a database table. The result `(20)` indicates that there are **20 airplanes** in the table.

Therefore, the answer to the original question "What is the total count of airplanes?" is **20**.


For this second example, we expect the table `airplanes` to be included in the LLM SQL analysis.

In [53]:
# Business question
question = "How many unique airplane producers are represented in the database?"

# Query/search 1 most similar results.
docs = collection2.query(
    query_texts=[question],
    n_results=1
)

pattern = r"<table_schemas>(.*)</table_schemas>"
table_schemas = re.search(pattern, docs["documents"][0][0], re.DOTALL).group(1)
print(f"ChromaDB - Schema Retrieval: \n{table_schemas.strip()}")

# Generate a prompt to get the LLM to provide an SQL query
SQL_SYS_PROMPT = PromptTemplate.from_template(tmp_sql_sys_prompt).format(
    question=question,
    table_schemas=table_schemas,
)

results = get_llm_sql_analysis(
    question=question,
    sql_sys_prompt=SQL_SYS_PROMPT,
    qna_sys_prompt=QNA_SYS_PROMPT
)

ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
LLM SQL Query: 
SELECT COUNT(DISTINCT Producer) FROM airline_db.airplanes;
LLM SQL Analysis: 
Based on the SQL query results, which show a single value of `(4)`, we can conclude that there are **4 unique airplane producers** represented in the database.

This result suggests that the database contains information about airplanes from 4 distinct manufacturers or producers. This could be useful information for various purposes, such as market analysis, sales tracking, or maintenance planning.

In summary, the answer to the original question is: there are 4 unique airplane producers represented in the database.


For this third example, we expect the table `flights` to be included in the LLM SQL analysis.

In [54]:
# Business question
question = "Get the total number of flights scheduled for each destination, grouped by arrival date"

# Query/search 1 most similar results.
docs = collection2.query(
    query_texts=[question],
    n_results=1
)

pattern = r"<table_schemas>(.*)</table_schemas>"
table_schemas = re.search(pattern, docs["documents"][0][0], re.DOTALL).group(1)
print(f"ChromaDB - Schema Retrieval: \n{table_schemas.strip()}")

# Generate a prompt to get the LLM to provide an SQL query
SQL_SYS_PROMPT = PromptTemplate.from_template(tmp_sql_sys_prompt).format(
    question=question,
    table_schemas=table_schemas,
)

results = get_llm_sql_analysis(
    question=question,
    sql_sys_prompt=SQL_SYS_PROMPT,
    qna_sys_prompt=QNA_SYS_PROMPT
)

ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>
LLM SQL Query: 

SELECT 
    Arrival_date, 
    Destination, 
    COUNT(Flight_number) as total_flights
FROM 
    airline_db.flights
GROUP BY 
    Arrival_date, 
    Destination
ORDER BY 
    Arrival_date;

LLM SQL Analysis: 
Based on the provided SQL query results, we can analyze and interpret the data to answer the original question.

**Results Analysis:**

The results show a list of 30 records, each representing a single flight schedu

For this fourth example, we expect the tables `airplanes` and `flights` to be included in the LLM SQL analysis.

In [55]:
# Business question
question = "Find the airplane IDs and producers for airplanes that have flown to New York"

# Query/search 1 most similar results.
docs = collection2.query(
    query_texts=[question],
    n_results=1
)

pattern = r"<table_schemas>(.*)</table_schemas>"
table_schemas = re.search(pattern, docs["documents"][0][0], re.DOTALL).group(1)
print(f"ChromaDB - Schema Retrieval: \n{table_schemas.strip()}")

# Generate a prompt to get the LLM to provide an SQL query
SQL_SYS_PROMPT = PromptTemplate.from_template(tmp_sql_sys_prompt).format(
    question=question,
    table_schemas=table_schemas,
)

results = get_llm_sql_analysis(
    question=question,
    sql_sys_prompt=SQL_SYS_PROMPT,
    qna_sys_prompt=QNA_SYS_PROMPT
)

ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>
LLM SQL Query: 

SELECT a.Airplane_id, a.Producer
FROM airline_db.airplanes a
JOIN airline_db.flights f ON a.Airplane_id = f.Airplane_id
WHERE f.Destination = 'New York';

LLM SQL Analysis: 
Based on the provided

## Conclusion 

We can observe that ChromaDB, paired with the selected AWS embedding model (for example `Amazon Titan Embedding`), retrieved the correct table schemas for all four examples above. This addressed the incorrect retrieval seen with the default embedding model, where a foreign key reference led to the wrong schema being returned. In these examples, schema retrieval was accurate for both single-table questions and a question that requires joining two tables.
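
To check this kind of claim on a larger set of questions, retrieval can be scored against the document id each question is expected to return. The sketch below is one possible harness. The expected ids assume that the table name variables hold `airplanes` and `flights`, and `stub_retrieve` is a deliberately naive placeholder that stands in for a real ChromaDB lookup.

```python
from typing import Callable, Dict


def retrieval_accuracy(expected: Dict[str, str], retrieve: Callable[[str], str]) -> float:
    """Fraction of questions whose top retrieved document id matches the expected id."""
    if not expected:
        return 0.0
    hits = sum(1 for question, doc_id in expected.items() if retrieve(question) == doc_id)
    return hits / len(expected)


# Questions from this notebook, mapped to the ids assumed for the ingested docs.
expected_ids = {
    "What is the total count of airplanes?": "airplanes",
    "How many unique airplane producers are represented in the database?": "airplanes",
    "Get the total number of flights scheduled for each destination, grouped by arrival date": "flights",
    "Find the airplane IDs and producers for airplanes that have flown to New York": "airplanes-flights",
}


def stub_retrieve(question: str) -> str:
    """Keyword-based placeholder; swap in a real vector search when running the notebook."""
    return "flights" if "flights" in question.lower() else "airplanes"


accuracy = retrieval_accuracy(expected_ids, stub_retrieve)
print(f"accuracy: {accuracy:.2f}")
```

In the notebook, the stub could be replaced with something like `lambda q: collection2.query(query_texts=[q], n_results=1)["ids"][0][0]`, which makes it possible to compare the default and AWS embedding collections on the same question set.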

## Clean Up Resources

In [None]:
# Delete resources
if llm_selected_service == 'Amazon SageMaker':
    llm_predictor.delete_model()
    llm_predictor.delete_endpoint()

if embedding_selected_service == 'Amazon SageMaker':
    embedding_predictor.delete_model()
    embedding_predictor.delete_endpoint()

# Thank you!