# Best-practices for prompt engineering Text-to-SQL on Llama3.1
---

## Introduction

This notebook introduces a versatile approach that leverages Llama 3.1 models on Amazon Bedrock/Amazon SageMaker JumpStart, including advanced prompt engineering, to convert natural language questions into executable SQL queries. Our approach generates SQL queries capable of joining data from multiple tables, enabling information retrieval from complex database structures. This multi-table capability is crucial in real-world scenarios where data is often distributed across various tables with intricate relationships, and queries need to combine information from multiple sources to provide comprehensive insights.

Moreover, our approach demonstrates high scalability through dynamically selecting and retrieving the most relevant table schemas based on the given natural language question. This scalability is achieved by employing intelligent schema matching algorithms powered by ChromaDB. ChromaDB analyzes the question and automatically identifies the appropriate tables and relationships required to construct the SQL query, eliminating the need for manual intervention.

Our solution can be applied in practical scenarios where companies manage numerous databases with intricate table relationships, such as in the finance industry for analyzing customer transactions across multiple accounts and products, or in healthcare for integrating patient records from various systems and data sources.


---
## Llama 3.1 Model Selection

There are THREE Llama 3.1 models available on Amazon Bedrock:

### 1. Llama 3.1 8B

- **Description:** Ideal for limited computational power and resources, faster training times, and edge devices. The model excels at text summarization, text classification, sentiment analysis, and language translation.
- **Context Window:** 128k
- **Languages:** English, German, French, Italian, Portuguese, Hindi, Spanish, and Thai.
- **Supported Use Cases:** Synthetic Text Generation, Text Classification, and Sentiment Analysis.

### 2. Llama 3.1 70B

- **Description:** Ideal for content creation, conversational AI, language understanding, research development, and enterprise applications. The model excels at text summarization and accuracy, text classification and nuance, sentiment analysis and nuance reasoning, language modeling, dialogue systems, code generation, and following instructions.
- **Context Window:** 128k
- **Languages:** English, German, French, Italian, Portuguese, Hindi, Spanish, and Thai.
- **Supported Use Cases:** Synthetic Text Generation and Accuracy, Text Classification and Nuance, Sentiment Analysis and Nuance Reasoning, Language Modeling, Dialogue Systems, and Code Generation.

### 2. Llama 3.1 405B

- **Description:** Ideal for enterprise level applications, research and development, synthetic data generation, and model distillation. The model excels at general knowledge, long-form text generation, machine translation, enhanced contextual understanding, advanced reasoning and decision making, better handling of ambiguity and uncertainty, increased creativity and diversity, steerability, math, tool use, multilingual translation, and coding.
- **Context Window:** 128k
- **Languages:** English, German, French, Italian, Portuguese, Hindi, Spanish, and Thai.
- **Supported Use Cases:** Synthetic Text Generation and Accuracy, Text Classification and Nuance, Sentiment Analysis and Nuance Reasoning, Language Modeling, Dialogue Systems, and Code Generation.


### Performance and Cost Trade-offs

The table below summarizes the model performance on the Massive Multitask Language Understanding ([MMLU](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md#instruction-tuned-models)) benchmark and their on-demand pricing on Amazon Bedrock.

| Model           | MMLU Score | Price per 1,000 Input Tokens | Price per 1,000 Output Tokens |
|-----------------|------------|------------------------------|-------------------------------|
| Llama 3.1 8B | 69.4%      | \$0.0003                   | \$0.0006                    |
| Llama 3.1 70B | 83.6%      | \$0.00265                   | \$0.0035                     |
| Llama 3.1 405B | 87.3%      | \$0.00532                   | \$0.016                     |

For more information, refer to the following links:

1. [Llama 3.1 Model Cards and Prompt Formats](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1)
2. [Amazon Bedrock Pricing Page](https://aws.amazon.com/bedrock/pricing/)

---

## The Approach to the Text-to-SQL Problem
This notebook covers the following approaches

### Few-shot text-to-SQL (Single Table vs Multiple Tables)
Few-shot text-to-SQL is an approach for querying databases by translating natural language questions into SQL queries, using only a few training examples. Providing just a few examples of natural language questions paired with the equivalent SQL queries allows models to learn the mapping from natural language to SQL. Reference : https://arxiv.org/abs/2305.12586

### Few-shot text-to-SQL powered by ChromaDB (Schema Retrieval vs Enhance Schema Retrieval with Sample Questions)

This approach leverages ChromaDB, a vector database, to assist the few-shot text-to-SQL translation process. ChromaDB can be used in two ways:

1. **Schema Retrieval**: In this method, we provide manually the table schema within the prompt. When a natural language question is provided, the model uses the schema information to aid in generating the SQL query.

2. **Schema Retrieval powered by ChromaDB**: In this method, ChromaDB stores the database schema information, which includes table names, column names, and their descriptions. When a natural language question is provided, the model can retrieve relevant schema information from ChromaDB to aid in generating the SQL query.

By leveraging ChromaDB, the few-shot text-to-SQL translation process can benefit from efficient schema and sample data retrieval, potentially leading to better performance and generalization across different databases and query types.

---

## Objectives

This notebook will provide code snippets to assist with implementing two differents approaches to converting a natural language question into a SQL query that would answer it.

For accessing LLMs the following code snippets are provided:
1. [Amazon Bedrock Invoke API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html)
2. [Amazon Bedrock Converse API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html)
3. [SageMaker Jumpstart](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.jumpstart.model.JumpStartModel)

In addition, you have the option of running the SQL on MySQL or PostgreSQL database.

---

## Contents

1. [Getting Started](#Getting-Started)
    + [Install Dependencies](#Step-0:-Install-Dependencies)
    + [Select Model Hosting Service](#Step-1:-Select-Hosting-Model-Service)
    + [Get Database and Schema details](Step-2:-Get-Database-and-Schema-details)
    + [Create helper functions](Step-3:-Create-helper-functions)
1. [Few-Shot Text-to-SQL](#Few-Shot-Text-to-SQL)
1. [Analyzing a Single Table with Few-Shot Learning](#Analyzing-a-Single-Table-with-Few-Shot-Learning)
    + [Create a Few-Shot Prompt](#Step-1:-Create-a-Few-Shot-Prompt)
    + [Execute Few-Shot Prompts](#Step-2:-Execute-Few-Shot-Prompts)
1. [Analyzing Multiple Table with Few-Shot Learning](#Analyzing-Multiple-Table-with-Few-Shot-Learning)
    + [Create a Few-Shot Prompt](#Step-1:-Create-a-Few-Shot-Prompt)
    + [Execute Few-Shot Prompts](#Step-2:-Execute-Few-Shot-Prompts)
1. [Limitations of Few-Shot Learning](#Limitations-of-Few-Shot-Learning)
1. [Few-shot text-to-SQL powered by ChromaDB](#Few-shot-text-to-SQL-powered-by-ChromaDB)
1. [Schema Retrieval](#Schema-Retrieval)
    + [Data Preprocessing](#Step-1:-Data-Preprocessing)
    + [Ingest docs into ChromaDB](#Step-2:-Ingest-docs-into-ChromaDB)
    + [Create a Few-Shot Prompt](#Step-3:-Create-a-Few-Shot-Prompt)
    + [Execute Few-Shot Prompts](#Step-4:-Execute-Few-Shot-Prompts)
    + [Conclusion](#Step-5-Conclusion)
1. [Enhanced Schema Retrieval with an Embedding Model](#Enhanced-Schema-Retrieval-with-an-Embedding-Model)
    + [Ingest docs into ChromaDB](#Step-1:-Ingest-docs-into-ChromaDB)
    + [Execute Few-Shot Prompts](#Step-2:-Execute-Few-Shot-Prompts)
    + [Conclusion](#Step-3-Conclusion)
---

### Tools

+ AWS Python SDKs [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html) to be able to submit API calls to [Amazon Bedrock](https://aws.amazon.com/bedrock/).

+ [LangChain](https://python.langchain.com/v0.1/docs/get_started/introduction/) is a framework that provides off the shelf components to make it easier to build applications with large language models. It is supported in multiple programming languages, such as Python, JavaScript, Java and Go. In this notebook, LangChain is used to build a prompt template.

+ [ChromaDB](https://www.trychroma.com/) is a vector database that enables efficient semantic search, storage, and retrieval of unstructured data like text, images, and audio. It's designed to work well with large language models (LLMs) and provides a simple and scalable way to build applications that can search and retrieve relevant information from vast amounts of data.

+ RDS (Relational Database Service) for [MySQL](https://aws.amazon.com/rds/mysql/) is a managed database service provided by Amazon Web Services (AWS). RDS for MySQL simplifies the setup, operation, and scaling of MySQL databases.

+ [PostgreSQL](https://aws.amazon.com/rds/postgresql/what-is-postgresql/) has become the preferred open source relational database for many enterprise developers and startups, powering leading business and mobile applications. With [Amazon RDS for PostgreSQL](https://aws.amazon.com/rds/postgresql/), you can deploy scalable PostgreSQL deployments in minutes with cost-efficient and resizable hardware capacity.
---

## Pre-requisites:

1. It is mandatory to have set up the database and sample data prior to using [this notebook](llama3-1-chromadb-text2sql-DB-Setup.ipynb).
2. Use kernel either `conda_python3`, `conda_pytorch_p310` or `conda_tensorflow2_p310`.
3. Install the required packages.
4. Access to the LLM API. 


### Amazon Bedrock Deployment

In this notebook, Llama 3.1 70B and 8B are used. You can easily switch between the two to evaluate the responses. By deploying the notebook through our cloudformation template, it is granted the appropriate IAM permissions to send API request to Bedrock. Refer [here](https://aws.amazon.com/blogs/aws/announcing-llama-3-1-405b-70b-and-8b-models-from-meta-in-amazon-bedrock/) for details on how Amazon Bedrock provides access to Meta’s Llama 3.1.

### SageMaker Deployment

#### Changing instance type
---
Models are supported on the following instance types:

 - Llama3.1 8B Text Generation: `ml.g5.4xlarge`, `ml.g5.8xlarge`, `ml.g5.12xlarge`, `ml.g5.24xlarge`, `ml.g5.48xlarge`, `ml.g6.4xlarge`, `ml.g6.8xlarge`, `ml.g6.12xlarge`, `ml.g6.24xlarge`, `ml.g6.48xlarge`, and `ml.p4d.24xlarge`
 - Llama3.1 70B Text Generation: `ml.g5.48xlarge`, `ml.g6.48xlarge`, `ml.p5.48xlarge` and `ml.p4d.24xlarge`
 - BGE Large En v1.5: `ml.g5.2xlarge`, `ml.c6i.xlarge`,`ml.g5.4xlarge`, `ml.g5.8xlarge`, `ml.p3.2xlarge`, and `ml.g4dn.2xlarge`

By default, the JumpStartModel class selects a default instance type available in your region. If you would like to use a different instance type, you can do so by specifying instance type in the JumpStartModel class.

`my_model = JumpStartModel(model_id=model_id, instance_type="ml.g5.12xlarge")`

---

## Getting Started

### Step 0: Install Dependencies

Here, we will install all the required dependencies to run this notebook.

In [1]:
!pip install boto3==1.34.127 -qU --force --quiet --no-warn-conflicts
!pip install mysql-connector-python==8.4.0 -qU --force --quiet --no-warn-conflicts
!pip install langchain==0.2.5 -qU --force --quiet --no-warn-conflicts
!pip install chromadb==0.5.0 -qU --force --quiet --no-warn-conflicts
!pip install numpy==1.26.4 -qU --force --quiet --no-warn-conflicts
!pip install psycopg2==2.9.9 -qU --force --quiet --no-warn-conflicts

**Note:** *When installing libraries using the pip, you may encounter errors or warnings during the installation process. These are generally not critical and can be safely ignored. However, after installing the libraries, it is recommended to restart the kernel or computing environment you are working in. Restarting the kernel ensures that the newly installed libraries are loaded properly and available for use in your code or workflow.*

<div class='alert alert-block alert-info'><b>NOTE:</b> Restart the kernel with the updated packages that are installed through the dependencies above</div>

In [None]:
# Restart the kernel
import os
os._exit(00)

#### Import the required modules to run the notbook

In [1]:
import boto3
import chromadb
from chromadb.api.types import (
    Documents,
    EmbeddingFunction,
    Embeddings,
)
import json
from langchain import PromptTemplate
import mysql.connector as MySQLdb
import re
from typing import Dict, List, Any
import yaml

import psycopg2 as PGdb
# from psycopg2 import sql
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT

In [2]:
# Setup Bedrock Client
bedrock_client = boto3.client(
    service_name='bedrock-runtime'
)

### Step 1: Select Hosting Model Service

Here, you can select to run this notebook using SageMaker JumpStart or Amazon Bedrock.

In [4]:
def ask_for_service():
    service = input("Do you want to run the LLM for this notebook using Amazon Bedrock (B) or Amazon Bedrock Converse API (C) or Amazon SageMaker JumpStart (S)? (default: B) ").strip().upper()
    if service in ['S', 'SAGEMAKER']:
        return 'Amazon SageMaker'
    elif service in ['B', 'BEDROCK (Invoke API)', '']:
        return 'Amazon Bedrock'
    elif service in  ['C', 'BEDROCK (Converse API)']:
        return 'Amazon Bedrock Converse API'
    else:
        print("Invalid input. Using Amazon Bedrock by default.")
        return 'Amazon Bedrock'

# Call the function and get the selected service
llm_selected_service = ask_for_service()

# Print the selected service
print(f"You have chosen to run the LLM for this notebook using [{llm_selected_service}].")

Do you want to run the LLM for this notebook using Amazon Bedrock (B) or Amazon Bedrock Converse API (C) or Amazon SageMaker JumpStart (S)? (default: B)  B


You have chosen to run the LLM for this notebook using [Amazon Bedrock].


In [5]:
def ask_for_service():
    service = input("Do you want to run the Embedding for this notebook using Amazon Bedrock (B) or Amazon SageMaker JumpStart (S)? (default: B) ").strip().upper()
    if service in ['S', 'SAGEMAKER']:
        return 'Amazon SageMaker'
    elif service in ['B', 'BEDROCK', '']:
        return 'Amazon Bedrock'
    else:
        print("Invalid input. Using Amazon Bedrock by default.")
        return 'Amazon Bedrock'

# Call the function and get the selected service
embedding_selected_service = ask_for_service()

# Print the selected service
print(f"You have chosen to run the Embedding for this notebook using [{embedding_selected_service}].")

Do you want to run the Embedding for this notebook using Amazon Bedrock (B) or Amazon SageMaker JumpStart (S)? (default: B)  B


You have chosen to run the Embedding for this notebook using [Amazon Bedrock].


In [6]:
%%time
if llm_selected_service == 'Amazon SageMaker':
    
    # Import the JumpStartModel class from the SageMaker JumpStart library
    from sagemaker.jumpstart.model import JumpStartModel

    # Specify the model ID for the Meta Llama 3 Instruct LLM model
    llama3_1_8b_id = "meta-textgeneration-llama-3-1-8b-instruct"
    llama3_1_70b_id = "meta-textgeneration-llama-3-1-70b-instruct"

    DEFULT_LLM_MODEL_ID = llama3_1_8b_id
    
    model = JumpStartModel(model_id=DEFULT_LLM_MODEL_ID)
    
    llm_predictor = model.deploy(accept_eula=True)
    
    print(f"LLM SageMaker Endpoint Name: [{llm_predictor.endpoint_name}].")
else:
    llm_predictor = None
    
    llama3_1_8b_id = "meta.llama3-1-8b-instruct-v1:0"
    llama3_1_70b_id = "meta.llama3-1-70b-instruct-v1:0"

    DEFULT_LLM_MODEL_ID = llama3_1_8b_id
    
    DEFAULT_EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v2:0"

if embedding_selected_service == 'Amazon SageMaker':
    # Import the JumpStartModel class from the SageMaker JumpStart library
    from sagemaker.jumpstart.model import JumpStartModel

    # Deploy BGE Large EN embedding model on Amazon SageMaker JumpStart:
    # Specify the model ID for the HuggingFace BGE Large EN Embedding model
    DEFAULT_EMBEDDING_MODEL_ID = "huggingface-sentencesimilarity-bge-large-en"

    text_embedding_model = JumpStartModel(model_id=DEFAULT_EMBEDDING_MODEL_ID)
    
    embedding_predictor = text_embedding_model.deploy()
    
    print(f"LLM SageMaker Endpoint Name: [{embedding_predictor.endpoint_name}].")
else:
    embedding_predictor = None
    DEFAULT_EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v2:0"

CPU times: user 4 μs, sys: 0 ns, total: 4 μs
Wall time: 7.15 μs


In [7]:
def ask_for_db():
    service = input("Do you want to run the LLM example for this notebook using MySQL (M) or PostgreSQL (P)? (default: M) ").strip().upper()
    if service in ['M', 'MYSQL']:
        return 'mysql'
    elif service in ['P', 'POSTGRESQL', '']:
        return 'postgresql'
    else:
        print("Invalid input. Using MySQL by default.")
        return 'mysql'

# Call the function and get the selected DB
llm_selected_db = ask_for_db()

# Print the selected service
print(f"You have chosen to run the LLM example for this notebook using [{llm_selected_db}].")

Do you want to run the LLM example for this notebook using MySQL (M) or PostgreSQL (P)? (default: M)  M


You have chosen to run the LLM example for this notebook using [mysql].


### Step 2: Get Database and Schema details

Here, we retrieve the services that are already deployed as a part of the cloudformation template to be used in building the application. The services include,

+ Secret ARN with RDS for MySQL and RDS for PostgreSQL Database credentials
+ Database Endpoints

In [8]:
stackname = "text2sql"  # If your stack name differs from "text2sql", please modify.

In [9]:
cfn = boto3.client('cloudformation')

response = cfn.describe_stack_resources(
    StackName=stackname
)
cfn_outputs = cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Outputs']

# Get rds secret arn and database endpoint from cloudformation outputs
for output in cfn_outputs:
    if 'SecretArnMySQL' in output['OutputKey']:
        mySQL_secret_id = output['OutputValue']

    if 'DatabaseEndpointMySQL' in output['OutputKey']:
        mySQL_db_host = output['OutputValue']

    if 'SecretArnPG' in output['OutputKey']:
        pg_secret_id = output['OutputValue']

    if 'DatabaseEndpointPG' in output['OutputKey']:
        pg_db_host = output['OutputValue']

In [10]:
secrets_client = boto3.client('secretsmanager')

# Get MySQL credentials from Secrets Manager
credentials = json.loads(secrets_client.get_secret_value(SecretId=mySQL_secret_id)['SecretString'])

# Get password and username from secrets
mySQL_db_password = credentials['password']
mySQL_db_user = credentials['username']
mySQL_db_name = "airline_db"

# Get PostgreSQL credentials from Secrets Manager
credentials = json.loads(secrets_client.get_secret_value(SecretId=pg_secret_id)['SecretString'])

# Get password and username from secrets
pg_db_password = credentials['password']
pg_db_user = credentials['username']
pg_db_name = "airline_db"

Establish the database connection (MySQL DB)

In [11]:
mySQL_db_conn = MySQLdb.connect(
    host=mySQL_db_host,
    user=mySQL_db_user,
    password=mySQL_db_password
)

Establish the database connection (PostgresSQL DB)

In [12]:
pg_db_conn = PGdb.connect(
    host=pg_db_host,
    database=pg_db_name,
    user=pg_db_user,
    password=pg_db_password
)

# PostgreSQL DB Setup
pg_db_conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)

#### Load table schema settings

In [13]:
def load_settings(file_path):
    """
    Reads a YAML file and returns its contents as a Python object.

    Args:
        file_path (str): The path to the YAML file.

    Returns:
        obj: The contents of the YAML file as a Python object.
    """
    try:
        with open(file_path, 'r') as file:
            data = yaml.safe_load(file)
        return data
    except FileNotFoundError:
        print(f"Error: The file '{file_path}' does not exist.")
    except yaml.YAMLError as exc:
        print(f"Error: Failed to parse the YAML file '{file_path}': {exc}")

In [14]:
# MySQL Table Setup

# Load table settings
settings_airplanes = load_settings('schemas/airplanes.yml')
table_airplanes = settings_airplanes['table_name']
table_schema_airplanes = settings_airplanes['table_schema']

# Load table settings
settings_flights = load_settings('schemas/flights.yml')
table_flights = settings_flights['table_name']
table_schema_flights = settings_flights['table_schema']

# Load table settings
settings_airplane_flights = load_settings('schemas/airplanes-flights.yml')

In [15]:
# PostgreSQL Table Setup

# Load table settings
pg_settings_airplanes = load_settings('schemas/airplanes-pg.yml')
pg_table_airplanes = pg_settings_airplanes['table_name']
pg_table_schema_airplanes = pg_settings_airplanes['table_schema']

# Load table settings
pg_settings_flights = load_settings('schemas/flights-pg.yml')
pg_table_flights = pg_settings_flights['table_name']
pg_table_schema_flights = pg_settings_flights['table_schema']

# Load table settings
pg_settings_airplane_flights = load_settings('schemas/airplanes-flights-pg.yml')

### Step 3: Create helper functions

To facilate the usability and readability of the SQL Query Analysis made by Llama 3.1, let's create a suite of helper functions.

##### Chat Completion (Invoke LLM and Return response)

The `sagemaker_chat_completion` function uses the SageMaker Endpoint to invoke the LLMs. The response from the LLM is extracted and returned as text.

In [16]:
def sagemaker_chat_completion(
    prompt: str,
    max_gen_len: int = 512,
    temperature: float = 0.5,
    top_p: float = 0.999
) -> str:
    """
    Generates a chat completion from a prompt using the llama3 model via Amazon SageMaker JumpStart.

    Args:
        prompt (str): The prompt text to generate completions for.
        max_gen_len (int, optional): The maximum length of the completion.
        temperature (float, optional): Sampling temperature for the model.
        top_p (float, optional): Top p sampling ratio for the model.

    Returns:
        str: The generated text completion.
    """
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_gen_len,
            "temperature": temperature,
            "top_p": top_p,
            "stop": ["<|eot_id|>"]
        }
    }

    # Call the model API to generate the completion
    response = llm_predictor.predict(body)
    completion = response.get('generated_text', '')

    return completion.strip()

The `bedrock_chat_completion` function uses the Bedrock client to invoke the LLMs. The response from the LLM is extracted and returned as text.

In [17]:
def bedrock_chat_completion(
    model_id: str,
    prompt: str,
    max_gen_len: int = 512,
    temperature: float = 0.5,
    top_p: float = 0.999
) -> str:
    """
    Generates a chat completion from a prompt using the llama3 model via Amazon Bedrock.

    Args:
        model_id (str): The ID of the llama3 model to use for completion.
        prompt (str): The prompt text to generate completions for.
        max_gen_len (int, optional): The maximum length of the completion.
        temperature (float, optional): Sampling temperature for the model.
        top_p (float, optional): Top p sampling ratio for the model.

    Returns:
        str: The generated text completion.
    """
    body = {
        "prompt": prompt,
        "max_gen_len": max_gen_len,
        "temperature": temperature,
        "top_p": top_p,
    }

    accept = "application/json"
    contentType = "application/json"

    # Convert the body dictionary to JSON string and encode it as bytes
    body_json = json.dumps(body)
    body_bytes = body_json.encode('utf-8')

    # Call the model API to generate the completion
    response = bedrock_client.invoke_model(
        body=body_bytes, modelId=model_id, accept=accept, contentType=contentType
    )
    response_body = response["body"].read()
    response_body = json.loads(response_body)
    completion = response_body.get("generation", "")

    return completion.strip()

The Function `bedrock_converseapi_completion` generates and executes an SQL query for a given question, and returns a comprehensive analyzes based on the sql query results. This uses the Amazon Bedrock Converse API.

In [18]:
def bedrock_converseapi_completion(
    model_id: str,
    conversation,
    max_tokens: int = 512,
    temperature: float = 0.5,
    top_p: float = 0.999,
    system_prompt=""
):
    """
    Generates a chat completion from a conversation using the llama3.1 model via 
    Amazon Bedrock Converse API.

    Args:
        model_id (str): The ID of the llama3 model to use for completion.
        conversation (object): The conversation so far user/assistant/user/...
        max_gen_len (int, optional): The maximum length of the completion.
        temperature (float, optional): Sampling temperature for the model.
        top_p (float, optional): Top p sampling ratio for the model.
        system_prompt: Any instructions to regulate the LLM behaviour.

    Returns:
        response: Object containing the assistant's response
    """  
    try:
        if system_prompt == "":
            # Send the message to the model, using the provided inference configuration.
            response = bedrock_client.converse(
                modelId=model_id,
                messages=conversation,
                inferenceConfig={
                    "maxTokens": max_tokens,
                    "temperature": temperature,
                    "topP": top_p
                },
            )
        else:
            # Send the message to the model, using the provided inference configuration.
            response = bedrock_client.converse(
                modelId=model_id,
                messages=conversation,
                inferenceConfig={
                    "maxTokens": max_tokens,
                    "temperature": temperature,
                    "topP": top_p
                },
                system=[{"text": system_prompt}],
            )

        return response
    
    except (ClientError, Exception) as e:
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        exit(1)

##### Instruction formatting, Query execution and LLM calling

The `format_instructions` function is designed to process the input from Llama 3 models, allowing a conversation between roles such as `system`, `user`, and `assistant`. To see more details about Llama 3 prompt formats, click [here](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/).

In [19]:
def format_instructions(instructions: List[Dict[str, str]]) -> List[str]:
    """Format instructions where conversation roles must alternate system/user/assistant/user/assistant/...
    Formats the prompt. Returns as a native prompt for Llama and also returns the System and User Prompt 
    texts that can be used with Bedrock Converse API.

    Args:
        instructions (List): A Dictionary of user and system instructions
    Returns:
        str: A formatted string containing the prompt, plaintext system prompt and plaintext user prompt.
    """
    systemPrompt = userPrompt = ""
    prompt: List[str] = []
    
    for instruction in instructions:
        if instruction["role"] == "system":
            prompt.extend(["<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n", (instruction["content"]).strip(), "<|eot_id|>"])
            systemPrompt = instruction["content"].strip()
        elif instruction["role"] == "user":
            prompt.extend(["<|start_header_id|>user<|end_header_id|>\n", (instruction["content"]).strip(), "<|eot_id|>"])
            userPrompt = instruction["content"].strip()
        else:
            raise ValueError(f"Invalid role: {instruction['role']}. Role must be either 'user' or 'system'.")
    prompt.extend(["<|start_header_id|>assistant<|end_header_id|>\n"])
    return "".join(prompt), "".join(systemPrompt), "".join(userPrompt)

The `execute_query` function will execute SQL queries, typically for retrieving data from a database, and format the results as a string for further processing or display. 

In [20]:
def execute_query(query: str, db_conn) -> str:
    """Execute an SQL query on the database connection and return the results as a string.

    Args:
        query (str): SQL query to execute
        db_conn (Connection object): Connection object to the database where the query needs to be executed.

    Returns:
        str: A formatted string containing the SQL results.
    """
    # Get a cursor from the database connection
    mycursor = db_conn.cursor()

    # Execute the SQL query
    mycursor.execute(query)

    # Fetch all result rows
    result_rows = mycursor.fetchall()

    # Convert result to string with newline between rows
    output_text = '\n'.join([str(x) for x in result_rows])
    return output_text

The Function `get_llm_sql_analysis` generates and executes an SQL query for a given question, and returns a comprehensive analyzes based on the sql query results.

In [21]:
def get_llm_sql_analysis(question: str, sql_sys_prompt: str, qna_sys_prompt: str, DatabaseType: str) -> str:
    """
    Generates an SQL query based on the given question, executes it, and returns an analysis of the results using Llama 3.

    Args:
        question (str): The input question for which an SQL query needs to be generated.
        sql_sys_prompt (str): The prompt to be used for generating the SQL query using Llama 3.
        qna_sys_prompt (str): The prompt to be used for analyzing the SQL query results using Llama 3.
        DatabaseType (enum): The type of database on which query needs to be executed. E.g. MySQL / PostgreSQL etc

    Returns:
        str: The analysis of the SQL query results provided by the language model.
    """

    try:
        # *****************************************************************************************
        # 1. Generate SQL Query
        # *****************************************************************************************
        print(f"\nUsing [{llm_selected_service}] to generate SQL for [{DatabaseType}]\n")
        
        if llm_selected_service == 'Amazon SageMaker':
            # Generates SQL query
            completion = sagemaker_chat_completion(
                prompt=sql_sys_prompt
            )
        elif llm_selected_service == 'Amazon Bedrock Converse API':
            # *****************************************************************************************
            # TODO: Perform additional steps for converse API
            conversation = [
                {
                    "role": "user",
                    "content": [{"text": question}],
                }
            ]
            
            response = bedrock_converseapi_completion(
                model_id=DEFULT_LLM_MODEL_ID, 
                conversation=conversation, 
                system_prompt=sql_sys_prompt)

            completion=response["output"]["message"]["content"][0]["text"]
            # *****************************************************************************************

        else:
            # Generates SQL query
            completion = bedrock_chat_completion(
                model_id=DEFULT_LLM_MODEL_ID,
                prompt=sql_sys_prompt
            )

        # *****************************************************************************************
        # 2. Extract SQL and Execute
        # *****************************************************************************************

        # Extract the SQL query from the completion returned from the first LLM call.
        pattern = r"<sql>(.*)</sql>"
        llm_sql_query = re.search(pattern, completion, re.DOTALL).group(1)
        print(f"\nLLM SQL Query: \n{llm_sql_query}")

        # Route the query according to the database passed. Connection object will vary.
        match DatabaseType:
            case "mysql":
                db_conn=mySQL_db_conn
            case "postgresql":
                db_conn=pg_db_conn

        # Execute SQL query based on the database provided.
        sql_results = execute_query(llm_sql_query, db_conn)

        print(f"\nsql_results = \n{sql_results}")

        # *****************************************************************************************
        # 3. Evaluate the response
        # *****************************************************************************************
        
        print(f"\nCalling LLM on [{llm_selected_service}] to Analyze and interpret the results from SQL query in relation to the original question.\n")
        
        # Now we will use a second LLM call to analyse the result of the query.
        if llm_selected_service == 'Amazon SageMaker':
            # Generates SQL analysis
            llm_sql_analysis = sagemaker_chat_completion(
                prompt=qna_sys_prompt.format(query_results=sql_results, question=question)
            )
        elif llm_selected_service == 'Amazon Bedrock Converse API':
            # *****************************************************************************************
            # TODO: Append Assistant response.

            # Append user's next question
            # Create a conversation object
            queryAnalyseResults = qna_sys_prompt.format(query_results=sql_results, question=question)

            conversation = [
                {
                    "role": "user",
                    "content": [{"text": queryAnalyseResults}],
                }
            ]

            # Now we will use a second llM call to analyse the result of the query.
            response = bedrock_converseapi_completion(
                model_id=DEFULT_LLM_MODEL_ID, 
                conversation=conversation)

            llm_sql_analysis = response["output"]["message"]["content"][0]["text"]
            # *****************************************************************************************
        else:
            # Generates SQL analysis
            llm_sql_analysis = bedrock_chat_completion(
                model_id=DEFULT_LLM_MODEL_ID,
                prompt=qna_sys_prompt.format(query_results=sql_results, question=question)
            )

        print(f"\nLLM SQL Analysis: \n{llm_sql_analysis}")

        return llm_sql_analysis
    except Exception as e:
        print(e)
        return e

##### Embedding Functions

The Class `AmazonBedrockEmbeddingFunction` initializes an embedding function with `Amazon Titan Embedding Model V2` that integrates with ChromaDB . This class can be further extended to add support for other embedding models available on Amazon Bedrock.

In [22]:
class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        session: "boto3.Session",
        model_name: str = "amazon.titan-embed-text-v2:0",
        **kwargs: Any,
    ):
        """Initialize AmazonBedrockEmbeddingFunction.

        Args:
            session (boto3.Session): The boto3 session to use.
            model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1"
            **kwargs: Additional arguments to pass to the boto3 client.

        Example:
            >>> import boto3
            >>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
            >>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = bedrock(texts)
        """

        self._model_name = model_name

        self._client = session.client(
            service_name="bedrock-runtime",
            **kwargs,
        )

    def __call__(self, input: Documents) -> Embeddings:
        accept = "*/*"
        content_type = "application/json"
        embeddings = []
        for text in input:
            input_body = {"inputText": text, "dimensions": 512, "normalize": True}
            body = json.dumps(input_body)
            response = self._client.invoke_model(
                body=body,
                modelId=self._model_name,
                accept=accept,
                contentType=content_type,
            )
            embedding = json.load(response.get("body")).get("embedding")
            embeddings.append(embedding)
        return embeddings

In [23]:
class AmazonSageMakerEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        **kwargs: Any,
    ):
        """Initialize AmazonSageMakerEmbeddingFunction.

        Args:
            **kwargs: Additional arguments to pass to the sagemaker embedding function.

        Example:
            >>> sagemaker = AmazonBedrockEmbeddingFunction()
            >>> text_inputs = ["Hello, world!", "How are you?"]
            >>> embeddings = sagemaker(texts)
        """


    def __call__(self, input: Documents) -> Embeddings:
        accept = "application/json"
        content_type = "application/json"

        embeddings = []
        for text in input:
            input_body = {"text_inputs": text, "mode": "embedding"}
            body = json.dumps(input_body).encode('utf-8')
            response = embedding_predictor.predict(
                body,
                {
                    "ContentType": content_type,
                    "Accept": accept,
                }
            )
            embedding = response.get("embedding")
            embeddings.append(embedding)
        return embeddings

## Few-Shot Text-to-SQL
With our database and tables filled with data, we're now ready to walk through the Few-Shot Text-to-SQL approach. We'll start by building some helper functions.

## Analyzing a Single Table with Few-Shot Learning

### Step 1: Create a Few-Shot Prompt
Here, we design our prompt template that will account for our question and answer, and formatted correctly for use with Llama 3.1 models.

First, we create a `system prompt` containing two parts:

1. `table_schema`. This is a description of the structure of the database table, including the name of the table, the names of the columns within each table, and the data types of each column. This information helps Llama 3.1 to understand the organization and contents of the table.

2. `question`. This is the specific request or information that the user wants to obtain from the table.

By including both the table schema and the user's question in the system prompt, we provide Llama 3.1 model a complete understanding of the table structure and the user's desired output.

In [24]:
instructions = [
    {
        "role": "system",
        "content": 
        """You are a {dbtype} query expert whose output is a valid sql query.

Only use the following tables:

It has the following schema:
<table_schema>
{table_schema}
<table_schema>

Please construct a valid SQL statement to answer the following the question, return only the {dbtype} query in between <sql></sql>.
"""
    },
    {
        "role": "user",
        "content": "{question}"
    }
]

# Format instructions in LLM specific format
# Call the format instructions to get System Prompt and User Prompt into separate variables.
# For Converse API we don't need to provide inputs in a model specific format
tmp_sql_sys_prompt, sysPt, userPt = format_instructions(instructions)

Next, we create a new `system prompt` containing two parts:

1. `query_results` represents the SQL query results after executing the prompt `tmp_sql_sys_prompt`. This is the raw data that Llama 3 model will use to generate its analysis.

2. `question`. This specifies the type of analysis or insight that the user wants Llama 3 model to provide based on the SQL query results.

By combining the SQL query results and the user's question into a single system prompt, we provide Llama 3 model all the information it needs to understand the context and provide a comprehensive analysis tailored to the user's request.

In [25]:
instructions = [
    {
        "role": "user",
        "content": """Given the following SQL query results:
{query_results}

And the original question:
{question}

Please provide an analysis and interpretation of the results to answer the original question.
"""
    }
]

QNA_SYS_PROMPT, sysPtI, userPtI = format_instructions(instructions)

Building on our last prompt, we'll now add a Single Shot example to our context to better hint the model what we expect for a response.

### Step 2: Execute Few Shot Prompts
The following cells will demonstrate different questions asked in natural language and the SQL generated by the LLM. The output is contained between the `<sql>` and `</sql>` tags

In [26]:
def Switch_DB_Schema(DatabaseType: str, SchemaNeeded: str):
    """
    Baesd on the database provided, returns the schema appropriate to the object provided

    Args:
        DatabaseType (str): mysql / postgresql
        SchemaNeeded (str): table_schema_airplanes / 

    Returns:
        str: The analysis of the SQL query results provided by the language model.
    """
    match DatabaseType:
        case "mysql":
            match SchemaNeeded:
                case "table_schema_airplanes":
                    schema_for_object = table_schema_airplanes
                case "table_schema_flights":
                    schema_for_object = table_schema_flights
        case "postgresql":
            match SchemaNeeded:
                case "table_schema_airplanes":
                    schema_for_object = pg_table_schema_airplanes
                case "table_schema_flights":
                    schema_for_object = pg_table_schema_flights
    return schema_for_object

In [28]:
%%time
# Business question
question = "What is the total count of airplanes?"

# Amazon Bedrock Invoke API expects a fully formatted prompt.
# Amazon Bedrock Converse API expects regular text
if (llm_selected_service == 'Amazon Bedrock') or (llm_selected_service == 'Amazon SageMaker'):
    tmp_sql_sys_prompt1 = tmp_sql_sys_prompt
    qna_sys_prompt1 = QNA_SYS_PROMPT
elif llm_selected_service == 'Amazon Bedrock Converse API':
    tmp_sql_sys_prompt1 = sysPt
    qna_sys_prompt1 = userPtI

SQL_SYS_PROMPT_s1 = PromptTemplate.from_template(tmp_sql_sys_prompt1).format(
    question=question,
    table_schema=Switch_DB_Schema(llm_selected_db, "table_schema_airplanes"),
    dbtype=llm_selected_db
)
print(SQL_SYS_PROMPT_s1)

results = get_llm_sql_analysis(
    question=question,
    sql_sys_prompt=SQL_SYS_PROMPT_s1,
    qna_sys_prompt=qna_sys_prompt1,
    DatabaseType=llm_selected_db
)


<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a mysql query expert whose output is a valid sql query.

Only use the following tables:

It has the following schema:
<table_schema>
CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)

<table_schema>

Please construct a valid SQL statement to answer the following the question, return only the mysql query in between <sql></sql>.<|eot_id|><|start_header_id|>user<|end_header_id|>
What is the total count of airplanes?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


Using [Amazon Bedrock] to generate SQL for [mysql]


LLM SQL Query: 
SELECT COUNT(Airplane_id) FROM airline_db.airplanes;

sql_results = 
(20,)

Calling LLM on [Amazon Bedrock] to Analyze and interpret the results from SQL query in relation to the original question.


LLM SQL Analysis: 
To answer the original qu

## Analyzing Multiple Table with Few-Shot Learning

### Step 1: Create a Few-Shot Prompt
Now, let's try the same approach using two tables.

First, we create a `system prompt` containing the same placeholders as before and including two table schemas. 

In [29]:
instructions = [
    {
        "role": "system",
        "content": 
        """You are a {dbtype} query expert whose output is a valid sql query.

Only use the following tables:

It has the following schemas:
<table_schemas>
<table_schema>
{table_schema1}
<table_schema>

<table_schema>
{table_schema2}
<table_schema>
<table_schemas>

Please construct a valid SQL statement to answer the following the question, return only the {dbtype} query in between <sql></sql>.
"""
    },
    {
        "role": "user",
        "content": "{question}"
    }
]

tmp_sql_sys_prompt, sysPt, userPt = format_instructions(instructions)


### Step 2: Execute Few Shot Prompts

In [30]:
# Business question
question = "What is the total count of flights per producer?"

# Amazon Bedrock Invoke API expects a fully formatted prompt.
# Amazon Bedrock Converse API expects regular text
if (llm_selected_service == 'Amazon Bedrock') or (llm_selected_service == 'Amazon SageMaker'):
    tmp_sql_sys_prompt1 = tmp_sql_sys_prompt
    qna_sys_prompt1 = QNA_SYS_PROMPT
elif llm_selected_service == 'Amazon Bedrock Converse API':
    tmp_sql_sys_prompt1 = sysPt
    qna_sys_prompt1 = userPtI

SQL_SYS_PROMPT_s1 = PromptTemplate.from_template(tmp_sql_sys_prompt1).format(
    question=question,
    table_schema1=Switch_DB_Schema(llm_selected_db, "table_schema_airplanes"),
    table_schema2=Switch_DB_Schema(llm_selected_db, "table_schema_flights"),
    dbtype=llm_selected_db
)

print(SQL_SYS_PROMPT_s1)

results = get_llm_sql_analysis(
    question=question,
    sql_sys_prompt=SQL_SYS_PROMPT_s1,
    qna_sys_prompt=qna_sys_prompt1,
    DatabaseType=llm_selected_db
)


<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a mysql query expert whose output is a valid sql query.

Only use the following tables:

It has the following schemas:
<table_schemas>
<table_schema>
CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)

<table_schema>

<table_schema>
CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)

<table_schema>
<table_schemas>

Please construct a valid SQ

## Limitations of Few-Shot Learning

Few-Shot Learning for text-to-SQL tasks, where a language model is trained on a limited number of examples to translate natural language queries into SQL queries, faces significant limitations. One of the key challenges is selecting the appropriate table schema that aligns with the user's question.

In a real-world scenario, databases often consist of numerous tables with intricate relationships, making it difficult for the model to identify the relevant tables and columns required to answer a given query accurately. 

To address this issue, we propose incorporating ChromaDB to facilitate the retrieval of table schemas that are tailored to the user's question.

Here's how ChromaDB can help overcome the table schema selection challenge:

**Table Schema Retrieval**: Each table schema in the database can be converted into a dense vector embedding, capturing its structural information and relationships. The top-ranked table schemas are retrieved and provided as input to the text-to-SQL model, significantly increasing the likelihood of generating accurate SQL queries.

We will review this approach further in the next section.

## Few-shot text-to-SQL powered by ChromaDB

Here, we will use ChromaDB and the few-shot technique to effeciently retrieve table schemas for better performance and generalization across different databases and query types.

## Schema Retrieval

In this approach, we will store only the table schemas in ChromaDB.

### Step 1: Data Preprocessing

The first step is to preprocess the data and create a document that will be ingested into ChromaDB. The final doc clearly separates the table schemas by using XML tags such as `<table_schema></table_schema>`.

In [31]:
# For MySQL
# The doc includes a structure format for clearly identifying the table schemas
doc1 = "<table_schemas>\n"
doc1 += f"<table_schema>\n {settings_airplanes['table_schema']} \n</table_schema>\n".strip()
doc1 += "\n</table_schemas>"
print(doc1)

# The doc includes a structure format for clearly identifying the table schemas
doc2 = "<table_schemas>\n"
doc2 += f"<table_schema>\n {settings_flights['table_schema']} \n</table_schema>\n".strip()
doc2 += "\n</table_schemas>"
print(doc2)

# The doc includes a structure format for clearly identifying the table schemas
doc3 = "<table_schemas>\n"
for table_schema in settings_airplane_flights['table_schemas']:
    doc3 += f"<table_schema>\n {table_schema} \n</table_schema>\n"
doc3 += "\n</table_schemas>".strip()
print(doc3)

<table_schemas>
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
</table_schemas>
<table_schemas>
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>
</table_schemas>
<table_schemas>
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR

In [32]:
# For PostgreSQL
# The doc includes a structure format for clearly identifying the table schemas
pg_doc1 = "<table_schemas>\n"
pg_doc1 += f"<table_schema>\n {pg_settings_airplanes['table_schema']} \n</table_schema>\n".strip()
pg_doc1 += "\n</table_schemas>"
print(pg_doc1)

# The doc includes a structure format for clearly identifying the table schemas
pg_doc2 = "<table_schemas>\n"
pg_doc2 += f"<table_schema>\n {pg_settings_flights['table_schema']} \n</table_schema>\n".strip()
pg_doc2 += "\n</table_schemas>"
print(pg_doc2)

# The doc includes a structure format for clearly identifying the table schemas
pg_doc3 = "<table_schemas>\n"
for table_schema in pg_settings_airplane_flights['table_schemas']:
    pg_doc3 += f"<table_schema>\n {table_schema} \n</table_schema>\n"
pg_doc3 += "\n</table_schemas>".strip()
print(doc3)

<table_schemas>
<table_schema>
 CREATE TABLE airplanes -- Table name
(
Airplane_id INTEGER, -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
</table_schemas>
<table_schemas>
<table_schema>
 CREATE TABLE flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INTEGER, -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>
</table_schemas>
<table_schemas>
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type

### Step 2: Ingest docs into ChromaDB

After the data is preprocessed, the next step is to ingest all `docs` into ChromaDB.

In [33]:
# Setup Chroma in-memory, for easy prototyping.
chroma_client = chromadb.Client()

In [34]:
# For MySQL
# Create collection using ChromaDB's internal embedding function
collection1 = chroma_client.get_or_create_collection(name="table-schemas-default-embedding", 
                                                     metadata={"hnsw:space": "cosine"})

# Add docs to the collection.
collection1.add(
    documents=[
        doc1,
        doc2,
        doc3
    ],
    metadatas=[
        {"source": "mysql", "database": mySQL_db_name, "table_name": table_airplanes},
        {"source": "mysql", "database": mySQL_db_name, "table_name": table_flights},
        {"source": "mysql", "database": mySQL_db_name, "table_name": f"{table_airplanes}-{table_flights}" }
    ],
    ids=[table_airplanes, table_flights, f"{table_airplanes}-{table_flights}"], # unique for each doc
)

/home/ec2-user/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:02<00:00, 30.7MiB/s]


In [35]:
# For PostgreSQL
# Create collection using ChromaDB's internal embedding function
# Make sure Modify the Collection name for each new Database
pg_collection1 = chroma_client.get_or_create_collection(name="pg-table-schemas-default-embedding", 
                                                        metadata={"hnsw:space": "cosine"})

# Add docs to the collection.
pg_collection1.add(
    documents=[
        pg_doc1,
        pg_doc2,
        pg_doc3
    ],
    metadatas=[
        {"source": "postgresql", "database": pg_db_name, "table_name": pg_table_airplanes},
        {"source": "postgresql", "database": pg_db_name, "table_name": pg_table_flights},
        {"source": "postgresql", "database": pg_db_name, "table_name": f"{pg_table_airplanes}-{pg_table_flights}" }
    ],
    ids=[pg_table_airplanes, pg_table_flights, f"{pg_table_airplanes}-{pg_table_flights}"], # unique for each doc
)

In [36]:
def RunPrompts(DatabaseType: str, ServiceType: str, question: str, mySQLCol, pgCol):
    """
    Helper function that does the following:
    - Picks up the ChromaDB collection (passed as input parameters) based on the DatabaseType
    - Retrieves relevant table schemas from ChromaDB based on the Business Question.
    - Invoke the LLM to return the SQL query for the table schemas retrieved.
    - Run the query against the database.
    - Invoke the LLM to analyse results of the query execution against the Business Question.

    Args:
        DatabaseType (str): mysql / postgresql
        ServiceType (str): 'Amazon Bedrock' / 'Amazon Bedrock Converse API'
        question (str): User's business question
        mySQLCol (object): ChromadB collection for mySQL Database
        pgCol (object): ChromadB collection for mySQL Database

    Returns:
        str: The analysis of the SQL query results provided by the language model.
    """
    # Route the query according to the database passed
    if DatabaseType == "mysql":
        # For MySQL DB
        collection_to_use = mySQLCol
    elif DatabaseType == "postgresql":
        # For PostgreSQL DB
        collection_to_use = pgCol
    
    # Query/search 1 most similar results.
    docs = collection_to_use.query(
        query_texts=[question],
        n_results=1
    )

    pattern = r"<table_schemas>(.*)</table_schemas>"
    table_schemas = re.search(pattern, docs["documents"][0][0], re.DOTALL).group(1)
    print(f"ChromaDB - Schema Retrieval: \n{table_schemas.strip()}")
    
    # Amazon Bedrock Invoke API expects a fully formatted prompt.
    # Amazon Bedrock Converse API expects regular text
    if (llm_selected_service == 'Amazon Bedrock') or (llm_selected_service == 'Amazon SageMaker'):
        tmp_sql_sys_prompt1 = tmp_sql_sys_prompt
        qna_sys_prompt1 = QNA_SYS_PROMPT
    elif ServiceType == 'Amazon Bedrock Converse API':
        tmp_sql_sys_prompt1 = sysPt
        qna_sys_prompt1 = userPtI

    SQL_SYS_PROMPT_s1 = PromptTemplate.from_template(tmp_sql_sys_prompt1).format(
        question=question,
        table_schemas=table_schemas,
        dbtype=DatabaseType
    )

    results = get_llm_sql_analysis(
        question=question,
        sql_sys_prompt=SQL_SYS_PROMPT_s1,
        qna_sys_prompt=qna_sys_prompt1,
        DatabaseType=DatabaseType
    )


### Step 3: Create a Few-Shot Prompt

Now, we'll use a few-shot approach using the retrieved tables from ChromaDB.

First, we create a `system prompt` containing a placeholder including any number of table schemas for ChromaDB. 

In [37]:
instructions = [
    {
        "role": "system",
        "content": 
        """You are a {dbtype} query expert whose output is a valid sql query.

Only use the following tables:

It has the following schemas:
<table_schemas>
{table_schemas}
<table_schemas>

Always combine the database name and table name to build your queries. You must identify these two values before proving a valid SQL query.

Please construct a valid SQL statement to answer the following the question, return only the {dbtype} query in between <sql></sql>.
"""
    },
    {
        "role": "user",
        "content": "{question}"
    }
]

tmp_sql_sys_prompt, sysPt, userPt = format_instructions(instructions)


### Step 4: Execute Few Shot Prompts

In this example, we expect the table `airplanes` to be used for the SQL llm analysis.

In [38]:
%%time
# Business question
question = "What is the total count of airplanes?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection1, pg_collection1)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock] to generate SQL for [mysql]


LLM SQL Query: 

SELECT COUNT(Airplane_id) AS total_airplanes
FROM airline_db.airplanes;


sql_results = 
(20,)

Calling LLM on [Amazon Bedrock] to Analyze and interpret the results from SQL query in relation to the original question.


LLM SQL Analysis: 
Unfortunately, the provided SQL query result doesn't contain enough information to directly answer the original question, "What is 

In [39]:
%%time
# Business question
question = "What is the total count of flights per producer?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection1, pg_collection1)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock] to generate SQL for [mysql]


LLM SQL Query: 

SELECT Producer, COUNT(*) as Total_Flights 
FROM airline_db.flights 
GROUP BY Producer;

1054 (42S22): Unknown column 'Producer' in 'field li

### Step 5: Conclusion

We can observe that ChromaDB was unable to retrieve the correct table schema for the "airplanes" table. The issue arose due to a confusion caused by a foreign key reference. Specifically, ChromaDB retrieved the "flights" table instead of the "airplanes" table because the "flights" table contains a field called "Airplane_id" which references the "airplanes" table as a foreign key. This foreign key reference led to the confusion, resulting in ChromaDB retrieving the wrong table.

To mitigate this issue, we will use a more robust embedding model like the `Amazon Titan Embedding Model`.

We will see this in action in the following section.

## Enhanced Schema Retrieval with an Embedding Model

In this approach, we will use an embedding model from AWS.

### Step 1: Ingest docs into ChromaDB

After the data is preprocessed, the next step is to ingest all `docs` into ChromaDB using your selected embedding model (`Amazon SageMaker BGE Large English` or `Amazon Titan Embedding V2`).

In [40]:
# Create embedding function with AWS
if embedding_selected_service == "Amazon SageMaker":
    aws_ef = AmazonSageMakerEmbeddingFunction()
else:
    session = boto3.Session()
    aws_ef = AmazonBedrockEmbeddingFunction(
        session=session,
        model_name=DEFAULT_EMBEDDING_MODEL_ID
    )
    

In [41]:
# Create collection using Amazon Titan Embedding model
collection2 = chroma_client.get_or_create_collection(name="table-schemas-aws-embedding-model", 
                                                     embedding_function=aws_ef, 
                                                     metadata={"hnsw:space": "cosine"})

# Add docs to the collection.
collection2.add(
    documents=[
        doc1,
        doc2,
        doc3
    ],
    metadatas=[
        {"source": "mysql", "database": mySQL_db_name, "table_name": table_airplanes},
        {"source": "mysql", "database": mySQL_db_name, "table_name": table_flights},
        {"source": "mysql", "database": mySQL_db_name, "table_name": f"{table_airplanes}-{table_flights}" }
    ],
    ids=[table_airplanes, table_flights, f"{table_airplanes}-{table_flights}"], # unique for each doc
)


In [42]:
# Create collection using Amazon Titan Embedding model
pg_collection2 = chroma_client.get_or_create_collection(name="pg-table-schemas-aws-embedding-model", 
                                                        embedding_function=aws_ef, 
                                                        metadata={"hnsw:space": "cosine"})

# Add docs to the collection.
pg_collection2.add(
    documents=[
        pg_doc1,
        pg_doc2,
        pg_doc3
    ],
    metadatas=[
        {"source": "postgresql", "database": pg_db_name, "table_name": pg_table_airplanes},
        {"source": "postgresql", "database": pg_db_name, "table_name": pg_table_flights},
        {"source": "postgresql", "database": pg_db_name, "table_name": f"{pg_table_airplanes}-{pg_table_flights}" }
    ],
    ids=[pg_table_airplanes, pg_table_flights, f"{pg_table_airplanes}-{pg_table_flights}"], # unique for each doc
)


### Step 2: Execute Few Shot Prompts

In this example, we expect the table `airplanes` to be included for the SQL llm analysis.

In [43]:
%%time
# Business question
question = "What is the total count of airplanes?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock] to generate SQL for [mysql]


LLM SQL Query: 
SELECT COUNT(Airplane_id) FROM airline_db.airplanes;

sql_results = 
(20,)

Calling LLM on [Amazon Bedrock] to Analyze and interpret the results from SQL query in relation to the original question.


LLM SQL Analysis: 
Unfortunately, the provided SQL query result `(20,)` is not a complete or meaningful result set. It appears to be a single value enclosed in parentheses, which is not a typical format for SQL query results.

However, assuming the result is a single value representing the count of airplanes, I'll provide an analysis and interpretation.

**Analysis:**

The result `(20,)` suggests that the SQL query has returned a single value, which is `20`. This val

In [44]:
%%time
# Business question
question = "What is the total count of flights per producer?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock] to generate SQL for [mysql]


LLM SQL Query: 

SELECT Producer, COUNT(*) as Total_flights
FROM airline_db.flights
JOIN airline_db.airplanes ON flights.Airplane_id = airplanes.Airplane_id
G

For this second example, we expect the table `airplanes` to be included for the SQL llm analysis.

In [45]:
%%time
# Business question
question = "How many unique airplane producers are represented in the database?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock] to generate SQL for [mysql]


LLM SQL Query: 

SELECT COUNT(DISTINCT Producer) FROM airline_db.airplanes


sql_results = 
(4,)

Calling LLM on [Amazon Bedrock] to Analyze and interpret the results from SQL query in relation to the original question.


LLM SQL Analysis: 
Based on the SQL query result `(4,)`, it appears that the query is returning a single value, which is `4`.

In the context of the original question, "How many unique airplane producers are represented in the database?", the result suggests that there are 4 unique airplane producers in the database.

This means that the database contains information about 4 different airplane manufacturers, and these are the only unique producers represented i

For this third example, we expect the table `flights` to be included for the SQL llm analysis.

In [46]:
%%time
# Business question
question = "Get the total number of flights scheduled for each destination, grouped by arrival date"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock] to generate SQL for [mysql]


LLM SQL Query: 

SELECT Arrival_date, COUNT(Destination) as Total_flights
FROM airline_db.flights
GROUP BY Arrival_date, Destination


sql_results = 
('2023-06-15', 1)
('2023-07-02', 1)
('2023-06-24', 1)
('2023-06-19', 1)
('2023-06-27', 1)
('2023-06-25', 1)
('2023-06-29', 1)
('2023-07-03', 1)
('2023-06-17', 1)
('2023-06-21', 1)
('2023-06-28', 1)
('2023-06-20', 1)
('2023-06-22', 1)
('2

For this fourth example, we expect the table `airplanes` and `flights` to be included for the SQL llm analysis.

In [47]:
%%time
# Business question
question = "Find the airplane IDs and producers for airplanes that have flown to New York"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE airline_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE airline_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock] to generate SQL for [mysql]


LLM SQL Query: 

SELECT A.Airplane_id, A.Producer 
FROM airline_db.airplanes A 
JOIN airline_db.flights F ON A.Airplane_id = F.Airplane_id 
WHERE F.Destinatio

### Step 3: Conclusion

We can observe that ChromaDB and `Amazon Titan Embedding` model were able to retrieve the correct table schemas for the previous examples.  After successfully implementing these solutions, the issue of incorrectly retrieved table schemas due to foreign key confusions was effectively addressed. The data retrieval process became more accurate and reliable, ensuring that the correct table schemas were consistently retrieved, even in the presence of complex table relationships and foreign key references.

## Clean Up Resources

In [48]:
%%time
# Delete resources
if llm_selected_service == 'Amazon SageMaker':
    llm_predictor.delete_model()
    llm_predictor.delete_endpoint()

if embedding_selected_service == 'Amazon SageMaker':
    embedding_predictor.delete_model()
    embedding_predictor.delete_endpoint()

CPU times: user 4 μs, sys: 0 ns, total: 4 μs
Wall time: 6.91 μs


# Thank you!