# Text-to-SQL using Llama 3.2 and ChromaDB

---

## Introduction

This notebook introduces an approach that combines Llama 3.2 models on Amazon Bedrock with prompt engineering to convert natural language questions into executable SQL queries. The generated SQL queries can join tables across multiple databases, enabling information retrieval from complex database structures. This multi-database capability matters in real-world scenarios, where data is often distributed across many tables with intricate relationships and queries must combine information from several sources to provide comprehensive insights.

---

## Llama 3.2 Model Selection

There are four Llama 3.2 models available on Amazon Bedrock:

#### Llama 3.2 1B (text input)
The most lightweight model in the Llama 3.2 collection, well suited to retrieval and summarization on edge devices and in mobile applications. Typical use cases include personal information management and multilingual knowledge retrieval.

#### Llama 3.2 3B (text input)
Designed for applications requiring low-latency inferencing and limited computational resources. It excels at text summarization, classification, and language translation tasks. This model is ideal for the following use cases: mobile AI-powered writing assistants and customer service applications.

#### Llama 3.2 11B Vision (text + image input)
Well-suited for content creation, conversational AI, language understanding, and enterprise applications requiring visual reasoning. The model demonstrates strong performance in text summarization, sentiment analysis, code generation, and following instructions, with the added ability to reason about images.

#### Llama 3.2 90B Vision (text + image input)
Meta’s most advanced model, ideal for enterprise-level applications. This model excels at general knowledge, long-form text generation, multilingual translation, coding, math, and advanced reasoning. 

For more information, refer to the following links:

1. [Llama 3.2 Model Cards and Prompt Formats](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2)
2. [Amazon Bedrock Pricing Page](https://aws.amazon.com/bedrock/pricing/)

---

## Approach to the Text-to-SQL Problem

This notebook covers the following approach:

### Few-shot text-to-SQL powered by ChromaDB (Schema Retrieval vs. Enhanced Schema Retrieval with Sample Questions)

This approach uses ChromaDB, a vector database, to support few-shot text-to-SQL translation. ChromaDB stores the database schema information, including table names, column names, and their descriptions. When a natural language question is provided, the relevant schema information is retrieved from ChromaDB and included in the prompt to help the model generate the SQL query. ChromaDB can be used in two ways:

1. **Using the default embedding**: ChromaDB's built-in default embedding function is used (for the ChromaDB version pinned in this notebook, the `all-MiniLM-L6-v2` sentence-transformers model, run locally through ONNX).

2. **Using a custom embedding**: We show two options: (1) the Hugging Face BGE Large EN embedding model on Amazon SageMaker JumpStart, or (2) the Amazon Titan Text Embeddings V2 model on Amazon Bedrock.

By leveraging ChromaDB, the few-shot text-to-SQL translation process can benefit from efficient schema and sample data retrieval, potentially leading to better performance and generalization across different databases and query types.
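The retrieval idea can be illustrated without any external services. The dependency-free sketch below mimics what a vector store does at query time: it embeds each schema document and the question, then returns the documents with the highest cosine similarity. The bag-of-words `embed` function, the `retrieve` helper, and the two schema strings are hypothetical stand-ins; ChromaDB instead uses a learned embedding model and an approximate nearest-neighbour index.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    """Toy embedding: lower-cased word counts (a stand-in for a real embedding model)."""
    return Counter(re.findall(r"[a-z0-9_]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(count * b[word] for word, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(question: str, docs: dict[str, str], k: int = 1) -> list[str]:
    """Return the ids of the k schema documents most similar to the question."""
    q = embed(question)
    ranked = sorted(docs, key=lambda doc_id: cosine(q, embed(docs[doc_id])), reverse=True)
    return ranked[:k]


# Hypothetical schema documents, one per table.
schema_docs = {
    "airplanes": "table airplanes: airplane_id, producer, type",
    "flights": "table flights: flight_id, airplane_id, departure, arrival, flight_date",
}

print(retrieve("Which producer built the airplane?", schema_docs))  # ['airplanes']
```

In the notebook, ChromaDB plays the role of `retrieve`, and the retrieved schema text is what gets placed into the few-shot prompt.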

**Note:** For simple Few-shot text-to-SQL (Single Table vs Multiple Tables), please review [this notebook](https://github.com/aws-samples/Meta-Llama-on-AWS/blob/917ee12690a8e43eb9699b88662e7144ce9acac4/text2sql-recipes/llama3-1-chromadb-text2sql.ipynb).


---

## Objectives

This notebook provides code snippets for two different approaches to converting a natural language question into a SQL query. Each generated query is then executed against the database to answer the original question.

To access the LLMs, code snippets are provided for:
1. Amazon Bedrock Invoke API
2. Amazon Bedrock Converse API
3. Amazon SageMaker JumpStart


---

## Contents

1. [Getting Started](#Getting-Started)
    + [Install Dependencies](#Step-0:-Install-Dependencies)
    + [Select Hosting Model Service](#Step-1:-Select-Hosting-Model-Service)
    + [Get Database and Schema details](#Step-2:-Get-Database-and-Schema-details)
    + [Create helper functions](#Step-3:-Create-helper-functions)
1. [Few-shot text-to-SQL powered by ChromaDB](#Few-shot-text-to-SQL-powered-by-ChromaDB)
    + [Schema Retrieval](#Schema-Retrieval)
    + [Data Preprocessing](#Step-1:-Data-Preprocessing)
    + [Ingest docs into ChromaDB](#Step-2:-Ingest-docs-into-ChromaDB)
    + [Create a Few-Shot Prompt](#Step-3:-Create-a-Few-Shot-Prompt)
    + [Execute Few-Shot Prompts](#Step-4:-Execute-Few-Shot-Prompts)
    + [Conclusion](#Step-5-Conclusion)
1. [Enhanced Schema Retrieval with ChromaDB and an Embedding Model](#Enhanced-Schema-Retrieval-with-an-Embedding-Model)
    + [Ingest docs into ChromaDB](#Step-1:-Ingest-docs-into-ChromaDB)
    + [Execute Few-Shot Prompts](#Step-2:-Execute-Few-Shot-Prompts)
    + [Conclusion](#Step-3-Conclusion)
---

### Tools

+ The AWS SDK for Python, [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html), used to submit API calls to [Amazon Bedrock](https://aws.amazon.com/bedrock/).

+ [LangChain](https://python.langchain.com/v0.1/docs/get_started/introduction/) is a framework that provides off-the-shelf components that make it easier to build applications with large language models. It is supported in multiple programming languages, such as Python, JavaScript, Java and Go. In this notebook, LangChain is used to build a prompt template.

+ [ChromaDB](https://www.trychroma.com/) is an open-source vector database that stores embeddings together with their source documents and metadata and supports similarity search over them. It is designed to work well with large language models (LLMs), making it straightforward to build applications that retrieve the information most relevant to a query.

+ RDS (Relational Database Service) for [MySQL](https://aws.amazon.com/rds/mysql/) is a managed database service provided by Amazon Web Services (AWS). RDS for MySQL simplifies the setup, operation, and scaling of MySQL databases.

---

## Pre-requisites

1. Set up the database and sample data using [this notebook](llama3-2-chromadb-text2sql-DB-Setup.ipynb) before running this one; this step is mandatory.
2. Use one of the following kernels: `conda_python3`, `conda_pytorch_p310`, or `conda_tensorflow2_p310`.
3. Install the required packages.
4. Access to the LLM API. 


### Amazon Bedrock Deployment

This notebook uses the Llama 3.2 1B model, but you can switch to any of the other Llama 3.2 models to compare responses. If you deploy the notebook through our CloudFormation template, it is granted the IAM permissions needed to send API requests to Amazon Bedrock.

See [this AWS blog post](https://aws.amazon.com/blogs/aws/introducing-llama-3-2-models-from-meta-in-amazon-bedrock-a-new-generation-of-multimodal-vision-and-lightweight-models/) for details on how Amazon Bedrock provides access to Meta’s Llama 3.2.

### SageMaker Deployment

#### Changing instance type
---
Models are supported on the following instance types:

 - Llama 3.2 1B Text Generation: `ml.g5.xlarge`, `ml.g5.2xlarge`, `ml.g5.4xlarge`, `ml.g5.8xlarge`, `ml.g5.12xlarge`, `ml.g5.24xlarge`, `ml.g5.48xlarge`, `ml.g6.xlarge`, `ml.g6.2xlarge`, `ml.g6.4xlarge`, `ml.g6.8xlarge`, `ml.g6.12xlarge`, `ml.g6.24xlarge`, `ml.g6.48xlarge`, `ml.p4d.24xlarge`, and `ml.p5.48xlarge`
 - Llama 3.2 3B Text Generation: `ml.g5.48xlarge`, `ml.g6.48xlarge`, `ml.p4d.24xlarge`, and `ml.p5.48xlarge`
 - BGE Large EN v1.5: `ml.g5.2xlarge`, `ml.c6i.xlarge`, `ml.g5.4xlarge`, `ml.g5.8xlarge`, `ml.p3.2xlarge`, and `ml.g4dn.2xlarge`

**Note:** By default, the `JumpStartModel` class selects a default instance type available in your Region. To use a different instance type, pass the `instance_type` argument to the `JumpStartModel` constructor:

`my_model = JumpStartModel(model_id=model_id, instance_type="ml.g5.12xlarge")`

---

## Getting Started

### Step 0: Install Dependencies

Here, we will install all the required dependencies to run this notebook.

In [1]:
!pip install boto3==1.35.32 -qU --force --quiet --no-warn-conflicts
!pip install mysql-connector-python==8.4.0 -qU --force --quiet --no-warn-conflicts
!pip install langchain==0.2.5 -qU --force --quiet --no-warn-conflicts
!pip install chromadb==0.5.0 -qU --force --quiet --no-warn-conflicts
!pip install numpy==1.26.4 -qU --force --quiet --no-warn-conflicts
!pip install psycopg2==2.9.9 -qU --force --quiet --no-warn-conflicts


**Note:** *When installing libraries with pip, you may see errors or warnings during installation. These are generally not critical and can be safely ignored. After installing the libraries, restart the kernel so that the newly installed versions are loaded and available to your code.*

<div class='alert alert-block alert-info'><b>NOTE:</b> Restart the kernel so that the packages installed above are loaded.</div>

In [None]:
# Restart the kernel: exiting the kernel process makes Jupyter start a fresh one
import os
os._exit(0)
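After the kernel restarts, you can optionally confirm that the pinned versions are the ones actually loaded. The helper below is a small sketch using the standard-library `importlib.metadata`; the name `installed_versions` is hypothetical, and the pins simply mirror the `pip install` cell above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Versions pinned in the install cell above (pip distribution names).
PINNED = {
    "boto3": "1.35.32",
    "mysql-connector-python": "8.4.0",
    "langchain": "0.2.5",
    "chromadb": "0.5.0",
    "numpy": "1.26.4",
    "psycopg2": "2.9.9",
}


def installed_versions(pins: dict[str, str]) -> dict[str, tuple[str, Optional[str]]]:
    """Map each package to (pinned, installed); installed is None if the package is missing."""
    report = {}
    for name, pinned in pins.items():
        try:
            report[name] = (pinned, version(name))
        except PackageNotFoundError:
            report[name] = (pinned, None)
    return report


for name, (pinned, found) in installed_versions(PINNED).items():
    status = "OK" if found == pinned else "MISMATCH"
    print(f"{name:<24} pinned={pinned:<8} installed={found}  {status}")
```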

#### Import the required modules to run the notebook

In [1]:
import boto3
import chromadb
from chromadb.api.types import (
    Documents,
    EmbeddingFunction,
    Embeddings,
)
import json
from langchain_core.prompts import PromptTemplate  # importing from the langchain root module is deprecated
import mysql.connector as MySQLdb
import re
from typing import Dict, List, Any
import yaml


In [2]:
# Setup Bedrock Client
bedrock_client = boto3.client(
    service_name='bedrock-runtime'
)

### Step 1: Select Hosting Model Service

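Here, you select how to run the notebook: the LLM can be called through Amazon Bedrock (Invoke API), the Amazon Bedrock Converse API, or Amazon SageMaker JumpStart, and the embedding model through Amazon Bedrock or Amazon SageMaker JumpStart.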

In [3]:
def ask_for_service():
    service = input("Do you want to run the LLM for this notebook using Amazon Bedrock (B) or Amazon Bedrock Converse API (C) or Amazon SageMaker JumpStart (S)? (default: B) ").strip().upper()
    if service in ['S', 'SAGEMAKER']:
        return 'Amazon SageMaker'
    # Input is upper-cased above, so compare against upper-case options only
    elif service in ['B', 'BEDROCK', '']:
        return 'Amazon Bedrock'
    elif service in ['C', 'CONVERSE']:
        return 'Amazon Bedrock Converse API'
    else:
        print("Invalid input. Using Amazon Bedrock by default.")
        return 'Amazon Bedrock'

# Call the function and get the selected service
llm_selected_service = ask_for_service()

# Print the selected service
print(f"You have chosen to run the LLM for this notebook using [{llm_selected_service}].")

Do you want to run the LLM for this notebook using Amazon Bedrock (B) or Amazon Bedrock Converse API (C) or Amazon SageMaker JumpStart (S)? (default: B)  C


You have chosen to run the LLM for this notebook using [Amazon Bedrock Converse API].


In [4]:
def ask_for_service():
    service = input("Do you want to run the Embedding for this notebook using Amazon Bedrock (B) or Amazon SageMaker JumpStart (S)? (default: B) ").strip().upper()
    if service in ['S', 'SAGEMAKER']:
        return 'Amazon SageMaker'
    elif service in ['B', 'BEDROCK', '']:
        return 'Amazon Bedrock'
    else:
        print("Invalid input. Using Amazon Bedrock by default.")
        return 'Amazon Bedrock'

# Call the function and get the selected service
embedding_selected_service = ask_for_service()

# Print the selected service
print(f"You have chosen to run the Embedding for this notebook using [{embedding_selected_service}].")

Do you want to run the Embedding for this notebook using Amazon Bedrock (B) or Amazon SageMaker JumpStart (S)? (default: B)  B


You have chosen to run the Embedding for this notebook using [Amazon Bedrock].


In [5]:
%%time
if llm_selected_service == 'Amazon SageMaker':
    # Import the JumpStartModel class from the SageMaker JumpStart library
    from sagemaker.jumpstart.model import JumpStartModel

    # Specify the model ID for the Meta Llama 3.2 Instruct LLM model
    llama3_2_1b_id = "meta-textgeneration-llama-3-2-1b-instruct"
    llama3_2_3b_id = "meta-textgeneration-llama-3-2-3b-instruct"
    llama3_2_11b_id = "meta-vlm-llama-3-2-11b-vision-instruct"
    llama3_2_90b_id = "meta-vlm-llama-3-2-90b-vision-instruct"
    
    DEFULT_LLM_MODEL_ID = llama3_2_1b_id
    
    model = JumpStartModel(model_id=DEFULT_LLM_MODEL_ID, instance_type="ml.g5.8xlarge")
    
    llm_predictor = model.deploy(accept_eula=True)
    
    print(f"\nLLM SageMaker Endpoint Name: [{llm_predictor.endpoint_name}].\n")
else:
    llm_predictor = None
    
    # Specify the model ID for the Meta Llama 3.2 Instruct LLM model
    llama3_2_1b_id = "us.meta.llama3-2-1b-instruct-v1:0"
    llama3_2_3b_id = "us.meta.llama3-2-3b-instruct-v1:0"
    llama3_2_11b_id = "us.meta.llama3-2-11b-instruct-v1:0"
    llama3_2_90b_id = "us.meta.llama3-2-90b-instruct-v1:0"

    DEFULT_LLM_MODEL_ID = llama3_2_1b_id
    
    DEFAULT_EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v2:0"

if embedding_selected_service == 'Amazon SageMaker':
    # Import the JumpStartModel class from the SageMaker JumpStart library
    from sagemaker.jumpstart.model import JumpStartModel

    # Deploy BGE Large EN embedding model on Amazon SageMaker JumpStart:
    # Specify the model ID for the HuggingFace BGE Large EN Embedding model
    DEFAULT_EMBEDDING_MODEL_ID = "huggingface-sentencesimilarity-bge-large-en"

    text_embedding_model = JumpStartModel(model_id=DEFAULT_EMBEDDING_MODEL_ID, instance_type="ml.g5.4xlarge" )
    
    embedding_predictor = text_embedding_model.deploy()
    
    print(f"\nEmbedding SageMaker Endpoint Name: [{embedding_predictor.endpoint_name}].\n")
else:
    embedding_predictor = None
    DEFAULT_EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v2:0"

CPU times: user 4 μs, sys: 1 μs, total: 5 μs
Wall time: 7.15 μs


In [6]:
# This notebook example uses Amazon RDS MySQL DB.
llm_selected_db = "mysql"

### Step 2: Get Database and Schema details

Here, we retrieve details of the resources already deployed by the CloudFormation template, for use in building the application. These include:

+ Secret ARN with RDS for MySQL Database credentials
+ Database Endpoints

In [7]:
stackname = "text2sql"  # If your stack name differs from "text2sql", please modify.

In [8]:
cfn = boto3.client('cloudformation')

response = cfn.describe_stack_resources(
    StackName=stackname
)
cfn_outputs = cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Outputs']

# Get rds secret arn and database endpoint from cloudformation outputs
for output in cfn_outputs:
    if 'SecretArnMySQL' in output['OutputKey']:
        mySQL_secret_id = output['OutputValue']

    if 'DatabaseEndpointMySQL' in output['OutputKey']:
        mySQL_db_host = output['OutputValue']
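The loop above matches `OutputKey` by substring, and a missing output only surfaces later as a `NameError`. A slightly stricter sketch collects all outputs into a dictionary and fails immediately if a required key is absent. It assumes the output keys are exactly `SecretArnMySQL` and `DatabaseEndpointMySQL`; the helper name and the sample response are hypothetical, while the `Stacks[0]['Outputs']` shape follows the `describe_stacks` call used above.

```python
def stack_outputs(describe_stacks_response: dict, required: list[str]) -> dict[str, str]:
    """Return {OutputKey: OutputValue} for the first stack, checking required keys exist."""
    outputs = describe_stacks_response["Stacks"][0].get("Outputs", [])
    values = {o["OutputKey"]: o["OutputValue"] for o in outputs}
    missing = [key for key in required if key not in values]
    if missing:
        raise KeyError(f"Stack outputs missing: {missing}")
    return values


# Hypothetical response; in the notebook this would be cfn.describe_stacks(StackName=stackname).
sample = {
    "Stacks": [{
        "Outputs": [
            {"OutputKey": "SecretArnMySQL",
             "OutputValue": "arn:aws:secretsmanager:us-east-1:123456789012:secret:example"},
            {"OutputKey": "DatabaseEndpointMySQL", "OutputValue": "example.rds.amazonaws.com"},
        ]
    }]
}

outputs = stack_outputs(sample, ["SecretArnMySQL", "DatabaseEndpointMySQL"])
print(outputs["DatabaseEndpointMySQL"])  # example.rds.amazonaws.com
```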


In [9]:
secrets_client = boto3.client('secretsmanager')

# Get MySQL credentials from Secrets Manager
credentials = json.loads(secrets_client.get_secret_value(SecretId=mySQL_secret_id)['SecretString'])

# Get password and username from secrets
mySQL_db_password = credentials['password']
mySQL_db_user = credentials['username']


Establish the database connection (MySQL DB)

In [10]:
mySQL_db_conn = MySQLdb.connect(
    host=mySQL_db_host,
    user=mySQL_db_user,
    password=mySQL_db_password
)

#### Load table schema settings

In [11]:
def load_settings(file_path):
    """
    Reads a YAML file and returns its contents as a Python object.

    Args:
        file_path (str): The path to the YAML file.

    Returns:
        obj: The contents of the YAML file as a Python object.
    """
    try:
        with open(file_path, 'r') as file:
            data = yaml.safe_load(file)
        return data
    except FileNotFoundError:
        print(f"Error: The file '{file_path}' does not exist.")
    except yaml.YAMLError as exc:
        print(f"Error: Failed to parse the YAML file '{file_path}': {exc}")

In [12]:
# MySQL Table Setup

# Load table settings - database: equipment_db | table_name: airplanes
settings_airplanes = load_settings('schemas/airplanes_ms.yml')
table_airplanes = settings_airplanes['table_name']
table_schema_airplanes = settings_airplanes['table_schema']
db_airplanes = settings_airplanes['database']

# Load table settings - database: transport_db | table_name: flights
settings_flights = load_settings('schemas/flights_ms.yml')
table_flights = settings_flights['table_name']
table_schema_flights = settings_flights['table_schema']
db_flights = settings_flights['database']

# Load table settings
settings_airplane_flights = load_settings('schemas/airplanes-flights_ms.yml')
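`load_settings` prints an error and returns `None` when a file is missing or malformed, which later surfaces as a confusing `TypeError` on lines such as `settings_airplanes['table_name']`. Below is a small sketch of a guard for the per-table files, checking the keys this cell reads (`table_name`, `table_schema`, `database`). The function name `require_settings` and the sample schema text are hypothetical.

```python
from typing import Any, Optional


def require_settings(settings: Optional[dict[str, Any]], path: str,
                     keys: tuple[str, ...] = ("table_name", "table_schema", "database")) -> dict[str, Any]:
    """Raise a clear error if the YAML settings failed to load or lack expected keys."""
    if settings is None:
        raise ValueError(f"Settings from '{path}' could not be loaded")
    missing = [k for k in keys if k not in settings]
    if missing:
        raise KeyError(f"'{path}' is missing keys: {missing}")
    return settings


# An in-memory dict standing in for yaml.safe_load output (schema text is hypothetical).
settings = require_settings(
    {"database": "equipment_db", "table_name": "airplanes",
     "table_schema": "airplane_id INT, producer VARCHAR(50)"},
    "schemas/airplanes_ms.yml",
)
print(settings["database"], settings["table_name"])  # equipment_db airplanes
```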


### Step 3: Create helper functions

To make the SQL query analysis produced by Llama 3.2 easier to run and read, let's create a set of helper functions.

##### Chat Completion (Invoke LLM and Return response)

The `sagemaker_chat_completion` function uses the SageMaker Endpoint to invoke the LLMs. The response from the LLM is extracted and returned as text.

In [13]:
def sagemaker_chat_completion(
    prompt: str,
    max_gen_len: int = 512,
    temperature: float = 0.1,
    top_p: float = 0.1
) -> str:
    """
    Generates a chat completion from a prompt using the llama3.2 model via Amazon SageMaker JumpStart.

    Args:
        prompt (str): The prompt text to generate completions for.
        max_gen_len (int, optional): The maximum length of the completion.
        temperature (float, optional): Sampling temperature for the model.
        top_p (float, optional): Top p sampling ratio for the model.

    Returns:
        str: The generated text completion.
    """
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_gen_len,
            "temperature": temperature,
            "top_p": top_p,
            "stop": ["<|eot_id|>"]
        }
    }

    # Call the model API to generate the completion
    response = llm_predictor.predict(body)
    completion = response.get('generated_text', '')

    return completion.strip()
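Depending on the serving container, a JumpStart text-generation endpoint may return either a dictionary or a list of dictionaries containing `generated_text`; `sagemaker_chat_completion` assumes the dictionary form. The sketch below is a more tolerant parser that could replace `response.get('generated_text', '')`. The function name `extract_generated_text` is hypothetical, and the list-shaped response is an assumption about some containers, not a documented guarantee for this endpoint.

```python
from typing import Any


def extract_generated_text(response: Any) -> str:
    """Return generated_text from a dict or a list-of-dicts response, or '' if absent."""
    if isinstance(response, list):
        # Take the first result when the endpoint returns a list
        response = response[0] if response else {}
    if isinstance(response, dict):
        return str(response.get("generated_text", "")).strip()
    return ""


print(extract_generated_text({"generated_text": " SELECT 1; "}))  # SELECT 1;
print(extract_generated_text([{"generated_text": "SELECT 2;"}]))   # SELECT 2;
```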

The `bedrock_chat_completion` function uses the Bedrock client to invoke the LLMs. The response from the LLM is extracted and returned as text.

In [14]:
def bedrock_chat_completion(
    model_id: str,
    prompt: str,
    max_gen_len: int = 512,
    temperature: float = 0.0,
    top_p: float = 0.0
) -> str:
    """
    Generates a chat completion from a prompt using the llama3 model via Amazon Bedrock.

    Args:
        model_id (str): The ID of the llama3 model to use for completion.
        prompt (str): The prompt text to generate completions for.
        max_gen_len (int, optional): The maximum length of the completion.
        temperature (float, optional): Sampling temperature for the model.
        top_p (float, optional): Top p sampling ratio for the model.

    Returns:
        str: The generated text completion.
    """
    body = {
        "prompt": prompt,
        "max_gen_len": max_gen_len,
        "temperature": temperature,
        "top_p": top_p,
    }

    accept = "application/json"
    contentType = "application/json"

    # Convert the body dictionary to JSON string and encode it as bytes
    body_json = json.dumps(body)
    body_bytes = body_json.encode('utf-8')

    # Call the model API to generate the completion
    response = bedrock_client.invoke_model(
        body=body_bytes, modelId=model_id, accept=accept, contentType=contentType
    )
    response_body = response["body"].read()
    response_body = json.loads(response_body)
    completion = response_body.get("generation", "")
    
    inputTokenCount = response_body.get("prompt_token_count", "")
    outputTokenCount = response_body.get("generation_token_count", "")
    stopReason = response_body.get("stop_reason", "")
    
    return completion.strip()
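`bedrock_chat_completion` reads `prompt_token_count`, `generation_token_count`, and `stop_reason` from the response body but then discards them. Keeping them is useful, for example, to detect a SQL answer that was cut off because `stop_reason` is `"length"` (the value Llama models on Bedrock report when `max_gen_len` is reached). Here is a sketch of a typed parser; the dataclass and function names are hypothetical, and the sample body is hand-written to match the fields the notebook reads.

```python
import json
from dataclasses import dataclass


@dataclass
class LlamaInvokeResult:
    """Fields that bedrock_chat_completion reads from a Llama response body."""
    generation: str
    prompt_token_count: int
    generation_token_count: int
    stop_reason: str


def parse_llama_invoke_body(raw: bytes) -> LlamaInvokeResult:
    """Decode an InvokeModel response body (bytes) into a typed result."""
    body = json.loads(raw)
    return LlamaInvokeResult(
        generation=body.get("generation", "").strip(),
        prompt_token_count=int(body.get("prompt_token_count", 0)),
        generation_token_count=int(body.get("generation_token_count", 0)),
        stop_reason=body.get("stop_reason", ""),
    )


# Hand-written body; in the notebook this is response["body"].read().
raw = (b'{"generation": " <sql>SELECT 1;</sql>", "prompt_token_count": 42, '
       b'"generation_token_count": 9, "stop_reason": "stop"}')
result = parse_llama_invoke_body(raw)
if result.stop_reason == "length":
    print("Warning: completion hit max_gen_len and may be truncated")
print(result.generation, result.stop_reason)
```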

The `bedrock_converseapi_completion` function sends a conversation (and an optional system prompt) to a model through the Amazon Bedrock Converse API and returns the raw Converse response, from which the caller extracts the assistant's text.

In [15]:
def bedrock_converseapi_completion(
    model_id: str,
    conversation,
    max_tokens: int = 512,
    temperature: float = 0.0,
    top_p: float = 0.0,
    system_prompt=""
):
    """
    Generates a chat completion from a conversation using the llama3.2 model via 
    Amazon Bedrock Converse API.

    Args:
        model_id (str): The ID of the llama3 model to use for completion.
        conversation (list): The conversation so far, as Converse messages (user/assistant/user/...).
        max_tokens (int, optional): The maximum number of tokens to generate.
        temperature (float, optional): Sampling temperature for the model.
        top_p (float, optional): Top p sampling ratio for the model.
        system_prompt (str, optional): Instructions that regulate the LLM's behaviour; omitted if empty.

    Returns:
        dict: The raw Converse API response containing the assistant's reply.
    """
    try:
        if system_prompt == "":
            # Send the message to the model, using the provided inference configuration.
            response = bedrock_client.converse(
                modelId=model_id,
                messages=conversation,
                inferenceConfig={
                    "maxTokens": max_tokens,
                    "temperature": temperature,
                    "topP": top_p
                },
            )
        else:
            # Send the message to the model, using the provided inference configuration.
            response = bedrock_client.converse(
                modelId=model_id,
                messages=conversation,
                inferenceConfig={
                    "maxTokens": max_tokens,
                    "temperature": temperature,
                    "topP": top_p
                },
                system=[{"text": system_prompt}],
            )

        return response
    
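    except Exception as e:  # botocore's ClientError is a subclass of Exception
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        # Re-raise instead of calling exit(), which is not a reliable way to stop a notebook cell
        raise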
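Unlike the raw Llama prompt, the Converse API takes a list of messages, with system instructions passed separately through the `system` parameter. To continue a conversation (the notebook's `TODO: Append Assistant response`), you append the assistant's reply and then the next user turn. Below is a minimal sketch; the helper names are hypothetical and the sample response is hand-written in the shape the notebook already reads (`output.message.content[0].text`).

```python
from typing import Any


def append_message(conversation: list[dict[str, Any]], role: str, text: str) -> list[dict[str, Any]]:
    """Append a Converse message ({'role', 'content': [{'text': ...}]}) and return the list."""
    if role not in ("user", "assistant"):
        raise ValueError("Converse messages use the 'user' or 'assistant' role")
    conversation.append({"role": role, "content": [{"text": text}]})
    return conversation


def reply_text(response: dict[str, Any]) -> str:
    """Extract the assistant's text from a Converse response, as the notebook does."""
    return response["output"]["message"]["content"][0]["text"]


# Hand-written response for illustration; real responses also carry 'usage' and other fields.
first_response = {
    "output": {"message": {"role": "assistant", "content": [{"text": "<sql>SELECT 1;</sql>"}]}},
    "stopReason": "end_turn",
}

conversation = append_message([], "user", "How many airplanes are there?")
append_message(conversation, "assistant", reply_text(first_response))
append_message(conversation, "user", "Now explain the result.")
print([m["role"] for m in conversation])  # ['user', 'assistant', 'user']
```

The resulting list could then be passed as `conversation` to `bedrock_converseapi_completion` for the follow-up call.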

##### Instruction formatting, Query execution and LLM calling

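The `format_instructions` function formats a list of `system` and `user` messages into the native Llama 3 prompt format, ending with an `assistant` header so that the model generates the assistant turn. It also returns the plain system and user texts so the same instructions can be passed to the Bedrock Converse API. For more details on Llama 3 prompt formats, see [the Llama 3 model card and prompt formats](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/).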

In [16]:
def format_instructions(instructions: List[Dict[str, str]]) -> tuple[str, str, str]:
    """Format system/user instructions as a native Llama 3 prompt.

    Messages are processed in order; only the 'system' and 'user' roles are supported.
    The prompt ends with an assistant header so that the model generates the assistant turn.
    The plaintext system and user prompts are also returned for use with the Bedrock Converse API.

    Args:
        instructions (List[Dict[str, str]]): Messages with 'role' and 'content' keys.
    Returns:
        tuple[str, str, str]: The formatted Llama prompt, the plaintext system prompt,
        and the plaintext user prompt (the last one seen for each role).
    """
    systemPrompt = userPrompt = ""
    prompt: List[str] = []
    
    for instruction in instructions:
        content = instruction["content"].strip()
        if instruction["role"] == "system":
            prompt.extend(["<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n", content, "<|eot_id|>"])
            systemPrompt = content
        elif instruction["role"] == "user":
            prompt.extend(["<|start_header_id|>user<|end_header_id|>\n", content, "<|eot_id|>"])
            userPrompt = content
        else:
            raise ValueError(f"Invalid role: {instruction['role']}. Role must be either 'user' or 'system'.")
    # The trailing assistant header cues the model to write its reply
    prompt.append("<|start_header_id|>assistant<|end_header_id|>\n")
    return "".join(prompt), systemPrompt, userPrompt
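To see what the Llama prompt looks like, here is a standalone re-statement of the layout `format_instructions` produces for one system message and one user message. The function name `llama3_prompt` is hypothetical; it exists only so the example runs on its own.

```python
def llama3_prompt(system: str, user: str) -> str:
    """Build the token layout format_instructions produces for one system + one user turn."""
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        f"{system.strip()}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n"
        f"{user.strip()}<|eot_id|>"
        # Open assistant header: the model continues from here
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )


print(llama3_prompt("You are a SQL expert.", "List all airplanes."))
```

Each turn is wrapped in a header and closed with `<|eot_id|>`; only the final assistant header is left open.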

The `execute_query` function will execute SQL queries, typically for retrieving data from a database, and format the results as a string for further processing or display. 

In [17]:
def execute_query(query: str, db_conn) -> str:
    """Execute an SQL query on the database connection and return the results as a string.

    Args:
        query (str): SQL query to execute
        db_conn (Connection object): Connection object to the database where the query needs to be executed.

    Returns:
        str: A formatted string containing the SQL results, one row per line.
    """
    # Get a cursor from the database connection
    mycursor = db_conn.cursor()
    try:
        # Execute the SQL query
        mycursor.execute(query)

        # Fetch all result rows
        result_rows = mycursor.fetchall()
    finally:
        # Always release the cursor, even if the query fails
        mycursor.close()

    # Convert result to string with newline between rows
    output_text = '\n'.join([str(x) for x in result_rows])
    return output_text
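Because `execute_query` runs SQL written by the model, it is worth refusing anything other than a read query before execution (and ideally connecting as a read-only database user). The sketch below uses the standard-library `sqlite3` module as a stand-in for MySQL so that it runs anywhere. The `is_read_only` check is a deliberately simple, hypothetical heuristic, not a full SQL parser, and the sample rows are made up.

```python
import sqlite3


def is_read_only(query: str) -> bool:
    """Heuristic: allow a single statement that starts with SELECT or WITH."""
    stripped = query.strip().rstrip(";").strip()
    first_word = stripped.split(None, 1)[0].upper() if stripped else ""
    # Reject stacked statements such as "SELECT 1; DROP TABLE t"
    return first_word in ("SELECT", "WITH") and ";" not in stripped


def execute_read_query(query: str, db_conn) -> str:
    """Run a read-only query and format rows one per line, like execute_query."""
    if not is_read_only(query):
        raise ValueError(f"Refusing to run non-read-only SQL: {query!r}")
    cursor = db_conn.cursor()
    try:
        cursor.execute(query)
        return "\n".join(str(row) for row in cursor.fetchall())
    finally:
        cursor.close()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE airplanes (airplane_id INTEGER, producer TEXT)")
conn.executemany("INSERT INTO airplanes VALUES (?, ?)", [(1, "Boeing"), (2, "Airbus")])
print(execute_read_query("SELECT * FROM airplanes ORDER BY airplane_id;", conn))
```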

The `get_llm_sql_analysis` function generates and executes an SQL query for a given question, and returns a comprehensive analysis based on the SQL query results.

In [18]:
def get_llm_sql_analysis(question: str, sql_sys_prompt: str, qna_sys_prompt: str, DatabaseType: str) -> str:
    """
    Generates an SQL query based on the given question, executes it, and returns an analysis of the results using Llama 3.

    Args:
        question (str): The input question for which an SQL query needs to be generated.
        sql_sys_prompt (str): The prompt to be used for generating the SQL query using Llama 3.
        qna_sys_prompt (str): The prompt to be used for analyzing the SQL query results using Llama 3.
        DatabaseType (str): The type of database on which the query is executed, e.g. "mysql" or "postgresql".

    Returns:
        str: The analysis of the SQL query results provided by the language model
            (if any step fails, the exception is printed and the exception object is returned).
    """

    try:
        # *****************************************************************************************
        # 1. Generate SQL Query
        # *****************************************************************************************
        print(f"\nUsing [{llm_selected_service}] to generate SQL for [{DatabaseType}]\n")
        
        if llm_selected_service == 'Amazon SageMaker':
            # Generates SQL query
            completion = sagemaker_chat_completion(
                prompt=sql_sys_prompt
            )
        elif llm_selected_service == 'Amazon Bedrock Converse API':
            # *****************************************************************************************
            # TODO: Perform additional steps for converse API
            conversation = [
                {
                    "role": "user",
                    "content": [{"text": question}],
                }
            ]
            
            response = bedrock_converseapi_completion(
                model_id=DEFULT_LLM_MODEL_ID, 
                conversation=conversation, 
                system_prompt=sql_sys_prompt)

            completion=response["output"]["message"]["content"][0]["text"]
            # *****************************************************************************************

        else:
            # Generates SQL query
            completion = bedrock_chat_completion(
                model_id=DEFULT_LLM_MODEL_ID,
                prompt=sql_sys_prompt
            )
        
        print(f"completion = \n{completion}\n")
        # *****************************************************************************************
        # 2. Extract SQL and Execute
        # *****************************************************************************************

        # Extract the SQL query from the completion returned from the first LLM call.
        pattern = r"<sql>(.*)</sql>"
        
        sr = re.search(pattern, completion, re.DOTALL)

        if sr == None:
            pattern = r"```sql(.*)```"
            sr = re.search(pattern, completion, re.DOTALL)

        llm_sql_query = sr.group(1)
        print(f"\nLLM SQL Query: \n{llm_sql_query}")
    
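        # Route the query according to the database passed. Connection object will vary.
        match DatabaseType:
            case "mysql":
                db_conn = mySQL_db_conn
            case "postgresql":
                # pg_db_conn must be created separately; this notebook only sets up MySQL.
                db_conn = pg_db_conn
            case _:
                raise ValueError(f"Unsupported database type: {DatabaseType}")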

        # Execute SQL query based on the database provided.
        sql_results = execute_query(llm_sql_query, db_conn)

        print(f"\nsql_results = \n{sql_results}")

        # *****************************************************************************************
        # 3. Evaluate the response
        # *****************************************************************************************
        
        print(f"\nCalling LLM on [{llm_selected_service}] to Analyze and interpret the results from SQL query in relation to the original question.\n")
        
        # Now we will use a second LLM call to analyse the result of the query.
        if llm_selected_service == 'Amazon SageMaker':
            # Generates SQL analysis
            llm_sql_analysis = sagemaker_chat_completion(
                prompt=qna_sys_prompt.format(query_results=sql_results, question=question)
            )
        elif llm_selected_service == 'Amazon Bedrock Converse API':
            # *****************************************************************************************
            # TODO: Append Assistant response.

            # Append user's next question
            # Create a conversation object
            queryAnalyseResults = qna_sys_prompt.format(query_results=sql_results, question=question)

            conversation = [
                {
                    "role": "user",
                    "content": [{"text": queryAnalyseResults}],
                }
            ]

            # Now we will use a second LLM call to analyse the result of the query.
            response = bedrock_converseapi_completion(
                model_id=DEFULT_LLM_MODEL_ID, 
                conversation=conversation)

            llm_sql_analysis = response["output"]["message"]["content"][0]["text"]
            # *****************************************************************************************
        else:
            # Generates SQL analysis
            llm_sql_analysis = bedrock_chat_completion(
                model_id=DEFULT_LLM_MODEL_ID,
                prompt=qna_sys_prompt.format(query_results=sql_results, question=question)
            )

        print(f"\nLLM SQL Analysis: \n{llm_sql_analysis}")

        return llm_sql_analysis
    except Exception as e:
        print(f"\nException Encountered: \n{e}\n")
        return e
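To test the SQL-extraction step outside the notebook, here is a standalone sketch (hypothetical name `extract_sql`) that takes the first `<sql>` block non-greedily, falls back to a fenced `sql` block, and raises a clear error when neither is present. Matching is case-insensitive here, which is a choice of this sketch rather than of the notebook.

```python
import re

# Try <sql>...</sql> first, then a fenced sql block, matching non-greedily.
_PATTERNS = (
    re.compile(r"<sql>(.*?)</sql>", re.DOTALL | re.IGNORECASE),
    re.compile(r"```sql(.*?)```", re.DOTALL | re.IGNORECASE),
)


def extract_sql(completion: str) -> str:
    """Return the first SQL block found in an LLM completion."""
    for pattern in _PATTERNS:
        match = pattern.search(completion)
        if match:
            return match.group(1).strip()
    raise ValueError("No <sql> tags or fenced sql block found in the completion")


print(extract_sql("Here you go:\n<sql>\nSELECT COUNT(*) FROM airplanes;\n</sql>"))
```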

##### Embedding Functions

The `AmazonBedrockEmbeddingFunction` class implements ChromaDB's `EmbeddingFunction` interface using the Amazon Titan Text Embeddings V2 model on Amazon Bedrock. It can be extended to support other embedding models available on Amazon Bedrock.

In [19]:
class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        session: "boto3.Session",
        model_name: str = "amazon.titan-embed-text-v2:0",
        **kwargs: Any,
    ):
        """Initialize AmazonBedrockEmbeddingFunction.

        Args:
            session (boto3.Session): The boto3 session to use.
            model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v2:0"
            **kwargs: Additional arguments to pass to the boto3 client.

        Example:
            >>> import boto3
            >>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
            >>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = bedrock(texts)
        """

        self._model_name = model_name

        self._client = session.client(
            service_name="bedrock-runtime",
            **kwargs,
        )

    def __call__(self, input: Documents) -> Embeddings:
        accept = "*/*"
        content_type = "application/json"
        embeddings = []
        for text in input:
            input_body = {"inputText": text, "dimensions": 512, "normalize": True}
            body = json.dumps(input_body)
            response = self._client.invoke_model(
                body=body,
                modelId=self._model_name,
                accept=accept,
                contentType=content_type,
            )
            embedding = json.load(response.get("body")).get("embedding")
            embeddings.append(embedding)
        return embeddings

In [20]:
class AmazonSageMakerEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        **kwargs: Any,
    ):
        """Initialize AmazonSageMakerEmbeddingFunction.

        Relies on a global `embedding_predictor` (a deployed SageMaker embedding
        endpoint) being defined before the function is called.

        Args:
            **kwargs: Additional arguments (currently unused).

        Example:
            >>> sagemaker = AmazonSageMakerEmbeddingFunction()
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = sagemaker(texts)
        """


    def __call__(self, input: Documents) -> Embeddings:
        accept = "application/json"
        content_type = "application/json"

        embeddings = []
        for text in input:
            input_body = {"text_inputs": text, "mode": "embedding"}
            body = json.dumps(input_body).encode('utf-8')
            response = embedding_predictor.predict(
                body,
                {
                    "ContentType": content_type,
                    "Accept": accept,
                }
            )
            embedding = response.get("embedding")
            embeddings.append(embedding)
        return embeddings
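Both classes above implement the calling convention ChromaDB expects from an embedding function: a list of documents in, one vector per document out. As a hedged illustration, the hypothetical `HashingEmbeddingFunction` below follows the same convention without calling AWS, which can be handy for testing the pipeline offline. It uses simple feature hashing, so it is not a semantic embedding model, and it deliberately does not subclass ChromaDB's `EmbeddingFunction` so that it stays dependency-free.

```python
import hashlib
import math
from typing import List


# Hypothetical toy embedding function for offline testing. Same calling
# convention as the classes above: list of strings in, one vector per string out.
# NOT a semantic model: similar meaning does not imply similar vectors.
class HashingEmbeddingFunction:
    def __init__(self, dimensions: int = 64):
        self.dimensions = dimensions

    def __call__(self, input: List[str]) -> List[List[float]]:
        embeddings = []
        for text in input:
            vector = [0.0] * self.dimensions
            for token in text.lower().split():
                # Feature hashing: map each whitespace token to one bucket
                digest = hashlib.md5(token.encode("utf-8")).hexdigest()
                vector[int(digest, 16) % self.dimensions] += 1.0
            # L2-normalize, mirroring "normalize": True in the Titan request body
            norm = math.sqrt(sum(v * v for v in vector)) or 1.0
            embeddings.append([v / norm for v in vector])
        return embeddings


ef = HashingEmbeddingFunction(dimensions=16)
vectors = ef(["Hello, world!", "How are you?"])
print(len(vectors), len(vectors[0]))  # 2 16
```

To plug it into a collection, subclass `EmbeddingFunction[Documents]` as the classes above do.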

## Few-shot text-to-SQL powered by ChromaDB

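We use ChromaDB to retrieve only the table schemas relevant to each question and insert them into a few-shot prompt. Retrieving schemas on demand, instead of sending every schema with every request, keeps prompts short and helps the approach generalize across different databases and query types.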

### Schema Retrieval

In this approach, we will store only the table schemas in ChromaDB.
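At query time, ChromaDB embeds the question and returns the stored documents whose embeddings are closest to it. Because the collections created below use `metadata={"hnsw:space": "cosine"}`, closeness is measured as cosine distance (1 minus cosine similarity). ChromaDB uses an approximate HNSW index; the sketch below is an exact, brute-force equivalent on made-up 3-dimensional vectors, meant only to show what `n_results=1` selects. The function names are illustrative, not part of ChromaDB.

```python
import math
from typing import Dict, List, Tuple


def cosine_distance(a: List[float], b: List[float]) -> float:
    """1 - cosine similarity: the metric selected by {"hnsw:space": "cosine"}."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def top_k(query_vec: List[float], doc_vecs: Dict[str, List[float]], k: int = 1) -> List[Tuple[str, float]]:
    """Exact nearest-neighbour search, smallest distance first (HNSW approximates this)."""
    scored = [(doc_id, cosine_distance(query_vec, vec)) for doc_id, vec in doc_vecs.items()]
    scored.sort(key=lambda pair: pair[1])
    return scored[:k]


# Made-up vectors standing in for real document embeddings
doc_vectors = {
    "airplanes": [1.0, 0.1, 0.0],
    "flights": [0.1, 1.0, 0.0],
    "airplanes-flights": [0.7, 0.7, 0.1],
}
print(top_k([0.9, 0.2, 0.0], doc_vectors, k=1)[0][0])  # airplanes
```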

### Step 1: Data Preprocessing

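The first step is to preprocess the data and create the documents to ingest into ChromaDB: one document per table (`doc1`, `doc2`) and one that combines both table schemas (`doc3`). Each document wraps its content in XML tags, `<table_schemas>` around the whole document and `<table_schema></table_schema>` around each individual schema, so the schemas are clearly separated.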

In [21]:
# For MySQL
# Wrap one or more table schemas in XML tags so each schema is clearly delimited
def build_schema_doc(table_schemas):
    doc = "<table_schemas>\n"
    for table_schema in table_schemas:
        doc += f"<table_schema>\n {table_schema} \n</table_schema>\n"
    doc += "</table_schemas>"
    return doc

# One document per table, plus one document combining both table schemas
doc1 = build_schema_doc([settings_airplanes['table_schema']])
doc2 = build_schema_doc([settings_flights['table_schema']])
doc3 = build_schema_doc(settings_airplane_flights['table_schemas'])


### Step 2: Ingest docs into ChromaDB

After the data is preprocessed, the next step is to ingest all `docs` into ChromaDB.

In [22]:
# Set up an in-memory Chroma client for easy prototyping (data is not persisted).
chroma_client = chromadb.Client()


In [23]:
# Delete the collection if it already exists
try:
    chroma_client.get_collection(name="table-schemas-default-embedding")
except Exception:
    # Collection does not exist (the exception type differs across ChromaDB versions)
    pass
else:
    chroma_client.delete_collection(name="table-schemas-default-embedding")


In [24]:
# For MySQL
# Create collection using ChromaDB's default embedding function (all-MiniLM-L6-v2)
collection1 = chroma_client.get_or_create_collection(name="table-schemas-default-embedding", 
                                                     metadata={"hnsw:space": "cosine"})

# Add docs to the collection.
collection1.add(
    documents=[
        doc1,
        doc2,
        doc3
    ],
    metadatas=[
        {"source": "mysql", "database": db_airplanes, "table_name": table_airplanes},
        {"source": "mysql", "database": db_flights, "table_name": table_flights},
        {"source": "mysql", "database": f"{db_airplanes}-{db_flights}", "table_name": f"{table_airplanes}-{table_flights}" }
    ],
    ids=[table_airplanes, table_flights, f"{table_airplanes}-{table_flights}"], # unique for each doc
)

pg_collection1 = None



In [25]:
def RunPrompts(DatabaseType: str, ServiceType: str, question: str, mySQLCol, pgCol):
    """
    Helper function that does the following:
    - Picks the ChromaDB collection (passed as input parameters) based on the DatabaseType.
    - Retrieves the relevant table schemas from ChromaDB based on the business question.
    - Invokes the LLM to generate a SQL query for the retrieved table schemas.
    - Runs the query against the database.
    - Invokes the LLM to analyze the query results against the business question.

    Args:
        DatabaseType (str): 'mysql' / 'postgresql'
        ServiceType (str): 'Amazon Bedrock' / 'Amazon SageMaker' / 'Amazon Bedrock Converse API'
        question (str): User's business question
        mySQLCol (object): ChromaDB collection for the MySQL database
        pgCol (object): ChromaDB collection for the PostgreSQL database

    Returns:
        str: The analysis of the SQL query results provided by the language model.
    """
    # Route the query according to the database passed
    if DatabaseType == "mysql":
        # For MySQL DB
        collection_to_use = mySQLCol
    elif DatabaseType == "postgresql":
        # For PostgreSQL DB
        collection_to_use = pgCol
    else:
        raise ValueError(f"Unsupported DatabaseType: {DatabaseType}")

    # Retrieve the single most similar document
    docs = collection_to_use.query(
        query_texts=[question],
        n_results=1
    )

    # Keep only the content between the outer <table_schemas> tags
    pattern = r"<table_schemas>(.*)</table_schemas>"
    table_schemas = re.search(pattern, docs["documents"][0][0], re.DOTALL).group(1)
    print(f"ChromaDB - Schema Retrieval: \n{table_schemas.strip()}")

    # The Amazon Bedrock Invoke API (and the Amazon SageMaker path) uses a fully formatted prompt.
    # The Amazon Bedrock Converse API expects plain message text.
    if ServiceType in ('Amazon Bedrock', 'Amazon SageMaker'):
        tmp_sql_sys_prompt1 = tmp_sql_sys_prompt
        qna_sys_prompt1 = QNA_SYS_PROMPT
    elif ServiceType == 'Amazon Bedrock Converse API':
        tmp_sql_sys_prompt1 = sysPt
        qna_sys_prompt1 = userPtI
    else:
        raise ValueError(f"Unsupported ServiceType: {ServiceType}")

    SQL_SYS_PROMPT_s1 = PromptTemplate.from_template(tmp_sql_sys_prompt1).format(
        question=question,
        table_schemas=table_schemas,
        dbtype=DatabaseType
    )

    results = get_llm_sql_analysis(
        question=question,
        sql_sys_prompt=SQL_SYS_PROMPT_s1,
        qna_sys_prompt=qna_sys_prompt1,
        DatabaseType=DatabaseType
    )
    return results


### Step 3: Create a Few-Shot Prompt

Here, we design a prompt template that accounts for the question and its answer and is formatted correctly for Llama 3.2 models.

First, we create a `system prompt` containing two parts:

1. `table_schemas`. This is a description of the structure of the database table(s), including the name of the table, the names of the columns within each table, and the data types of each column. This information helps Llama 3.2 to understand the organization and contents of the table.

2. `question`. This is the specific request or information that the user wants to obtain from the table.

By including both the table schemas and the user's question in the prompt, we give the Llama 3.2 model a complete picture of the table structure and the user's desired output.

Now we apply this few-shot approach to the tables retrieved from ChromaDB. The `system prompt` below contains a `{table_schemas}` placeholder that can hold any number of retrieved table schemas, while the user message holds the `{question}`.

In [26]:
instructions = [
    {
        "role": "system",
        "content": 
        """You are a {dbtype} query expert whose output is a valid sql query.

Only use the following tables, which have these schemas:
<table_schemas>
{table_schemas}
</table_schemas>

Always combine the database name and table name to build your queries. You must identify these two values before providing a valid SQL query.

Please construct a valid SQL statement to answer the following question. Return only the {dbtype} query between <sql></sql> tags.
"""
    },
    {
        "role": "user",
        "content": "{question}"
    }
]

tmp_sql_sys_prompt, sysPt, userPt = format_instructions(instructions)
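The prompt asks the model to return the query between `<sql></sql>` tags, but the sample completions below show the model sometimes answering with a Markdown `sql` code fence instead. The notebook's own parsing lives inside `get_llm_sql_analysis`, which is not shown in this section; the hypothetical `extract_sql` helper below is only a sketch of a tolerant parser that tries the requested tags first and falls back to the fence.

```python
import re
from typing import Optional

FENCE = "`" * 3  # a Markdown code fence, built here to avoid nesting fences


# Hypothetical helper: return the SQL statement contained in an LLM completion,
# or None if no recognizable SQL block is found.
def extract_sql(completion: str) -> Optional[str]:
    patterns = [
        r"<sql>(.*?)</sql>",             # format requested by the prompt
        FENCE + r"sql\s*(.*?)" + FENCE,  # Markdown fence seen in sample outputs
    ]
    for pattern in patterns:
        match = re.search(pattern, completion, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1).strip()
    return None


tagged = "<sql>SELECT COUNT(*) FROM transport_db.flights;</sql>"
fenced = "Here is the query:\n" + FENCE + "sql\nSELECT COUNT(*)\nFROM equipment_db.airplanes;\n" + FENCE
print(extract_sql(tagged))  # SELECT COUNT(*) FROM transport_db.flights;
print(extract_sql(fenced))
```

A truncated completion with no closing tag or fence returns `None`, which the caller can treat as a signal to retry.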


Next, we create a second prompt containing two parts:

1. `query_results`. The results of executing the SQL query that the LLM generated from `tmp_sql_sys_prompt`. This is the raw data that the Llama 3.2 model will analyze.

2. `question`. The user's original question, which tells the Llama 3.2 model what analysis or insight to provide from the SQL query results.

By combining the SQL query results and the user's question into a single prompt, we give the Llama 3.2 model the context it needs to produce an analysis tailored to the user's request.
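In the sample runs below, `query_results` is a bare tuple such as `(20,)`, and the analysis model then guesses at column names that do not exist. One possible mitigation, sketched here as an assumption rather than the notebook's method, is to send the column names along with the rows. DB-API 2.0 cursors (PEP 249) expose column names through `cursor.description`. The demo uses Python's built-in `sqlite3` as a stand-in for MySQL with made-up rows, and `format_query_results` is a hypothetical helper.

```python
import sqlite3
from typing import Any, List, Sequence, Tuple


# Hypothetical helper: render rows together with their column names so the
# analysis prompt states what each value means, instead of a bare tuple.
def format_query_results(columns: Sequence[str], rows: List[Tuple[Any, ...]]) -> str:
    header = " | ".join(columns)
    lines = [header, "-" * len(header)]
    for row in rows:
        lines.append(" | ".join(str(value) for value in row))
    return "\n".join(lines)


# Demo against an in-memory SQLite table (stand-in for MySQL, made-up rows)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE airplanes (Airplane_id INTEGER PRIMARY KEY, Producer TEXT, Type TEXT)")
conn.executemany(
    "INSERT INTO airplanes VALUES (?, ?, ?)",
    [(1, "Boeing", "737"), (2, "Airbus", "A320"), (3, "Boeing", "787")],
)
cursor = conn.execute("SELECT COUNT(*) AS total_airplanes FROM airplanes")
# cursor.description holds one 7-item tuple per column; item 0 is the column name
columns = [description[0] for description in cursor.description]
formatted = format_query_results(columns, cursor.fetchall())
print(formatted)
```

The formatted string could then be passed as `query_results` in place of the raw tuple.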

In [27]:
instructions = [
    {
        "role": "user",
        "content": """Given the following SQL query results:
{query_results}

And the original question:
{question}

Please provide an analysis and interpretation of the results to answer the original question.
"""
    }
]

QNA_SYS_PROMPT, sysPtI, userPtI = format_instructions(instructions)
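Both prompts above go through `format_instructions`, which is defined earlier in the notebook and not shown in this section. For the Invoke API path, which expects a fully formatted prompt, the text must follow the Llama 3.x chat template with its special header and end-of-turn tokens. The hypothetical `format_llama3_prompt` below is a minimal sketch of that template, assuming plain system and user messages; it is not the notebook's implementation.

```python
from typing import Dict, List


# Hypothetical helper: render chat messages in the Llama 3.x prompt format.
# Each turn is "<|start_header_id|>role<|end_header_id|>\n\ncontent<|eot_id|>";
# the prompt ends with an open assistant header so the model writes the reply.
def format_llama3_prompt(messages: List[Dict[str, str]]) -> str:
    prompt = "<|begin_of_text|>"
    for message in messages:
        prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
        prompt += f"{message['content'].strip()}<|eot_id|>"
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt


example = format_llama3_prompt([
    {"role": "system", "content": "You are a mysql query expert."},
    {"role": "user", "content": "What is the total count of airplanes?"},
])
print(example)
```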

### Step 4: Execute Few Shot Prompts

In this example, we expect the `airplanes` table to be used for the SQL LLM analysis.

In [30]:
%%time
# Business question
question = "What is the total count of airplanes?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection1, pg_collection1)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE equipment_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE transport_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES equipment_db.airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock Converse API] to generate SQL for [mysql]

completion = 
To answer the question, we need to identify the database name and table name. Based on the given schemas, we can as

In [32]:
%%time
# Business question
question = "What is the total count of flights?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection1, pg_collection1)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE equipment_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE transport_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES equipment_db.airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock Converse API] to generate SQL for [mysql]

completion = 
To get the total count of flights, you can use the following SQL query:

```sql
SELECT COUNT(*) 
FROM transport_db.

In [33]:
%%time
# Business question
question = "What is the total count of flights per producer?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection1, pg_collection1)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE equipment_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE transport_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES equipment_db.airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock Converse API] to generate SQL for [mysql]

completion = 
```sql
SELECT 
    Producer, 
    COUNT(*) as Total_count
FROM 
    transport_db.flights
GROUP BY 
    Producer
ORD

### Step 5: Conclusion

We can observe that, with ChromaDB's default embedding function, retrieval is not precise: for all three questions, ChromaDB returned the combined document containing both the `airplanes` and `flights` schemas. That is correct for the third question, which needs a join, but not for the first two, which each need only one table. A likely cause is the foreign key reference: the `flights` schema contains an `Airplane_id` column that references `equipment_db.airplanes`, so airplane-related terms appear in both schemas and blur the distinction between the documents.

To mitigate this issue, we will use a more robust embedding model, such as the `Amazon Titan Embedding Model`.

We will see this in action in the following section.

## Enhanced Schema Retrieval with an Embedding Model

In this approach, we replace ChromaDB's default embedding function with an embedding model hosted on AWS.

### Step 1: Ingest docs into ChromaDB

After the data is preprocessed, the next step is to ingest all `docs` into ChromaDB using your selected embedding model (`Amazon SageMaker BGE Large English` or `Amazon Titan Embedding V2`).

In [34]:
# Create embedding function with AWS
if embedding_selected_service == "Amazon SageMaker":
    aws_ef = AmazonSageMakerEmbeddingFunction()
else:
    session = boto3.Session()
    aws_ef = AmazonBedrockEmbeddingFunction(
        session=session,
        model_name=DEFAULT_EMBEDDING_MODEL_ID
    )
    

In [35]:
# Delete the collection if it already exists
try:
    chroma_client.get_collection(name="table-schemas-aws-embedding-model")
except Exception:
    # Collection does not exist (the exception type differs across ChromaDB versions)
    pass
else:
    chroma_client.delete_collection(name="table-schemas-aws-embedding-model")


In [36]:
# Create collection using the selected AWS embedding function (Amazon Titan or SageMaker BGE)
collection2 = chroma_client.get_or_create_collection(name="table-schemas-aws-embedding-model", 
                                                     embedding_function=aws_ef, 
                                                     metadata={"hnsw:space": "cosine"})

collection2.add(
    documents=[
        doc1,
        doc2,
        doc3
    ],
    metadatas=[
        {"source": "mysql", "database": db_airplanes, "table_name": table_airplanes},
        {"source": "mysql", "database": db_flights, "table_name": table_flights},
        {"source": "mysql", "database": f"{db_airplanes}-{db_flights}", "table_name": f"{table_airplanes}-{table_flights}" }
    ],
    ids=[table_airplanes, table_flights, f"{table_airplanes}-{table_flights}"], # unique for each doc
)

pg_collection2 = None

### Step 2: Execute Few Shot Prompts

In this example, we expect the `airplanes` table to be included for the SQL LLM analysis.

In [37]:
%%time
# Business question
question = "What is the total count of airplanes?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE equipment_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock Converse API] to generate SQL for [mysql]

completion = 
To get the total count of airplanes, you can use the following SQL query:

```sql
SELECT COUNT(*) 
FROM equipment_db.airplanes;
```

This query will return the total count of airplanes in the `airplanes` table.


LLM SQL Query: 

SELECT COUNT(*) 
FROM equipment_db.airplanes;


sql_results = 
(20,)

Calling LLM on [Amazon Bedrock Converse API] to Analyze and interpret the results from SQL query in relation to the original question.


LLM SQL Analysis: 
**Analysis and Interpretation**

The given SQL query results are:

(20,)

This result set contains a single row with two columns: `id` and `count`. The `id` column is not provided, but based on the conte

In [38]:
%%time
# Business question
question = "What is the total count of flights?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE transport_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES equipment_db.airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock Converse API] to generate SQL for [mysql]

completion = 
To get the total count of flights, you can use the following SQL query:

```sql
SELECT COUNT(*) 
FROM transport_db.flights;
```

This query will return the total count of flights in the `transport_db` database.


LLM SQL Query: 

SELECT COUNT(*) 
FROM transport_db.flights;


sql_results = 
(20,)

Calling LLM on [Amazon Bedrock Converse API] to 

In [41]:
%%time
# Business question
question = "What is the total count of flights per producer?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE equipment_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE transport_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES equipment_db.airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock Converse API] to generate SQL for [mysql]

completion = 
To answer the question, we need to join the `airplanes` table with the `flights` table based on the `Airplane_id` c

For this example, we expect the `airplanes` table to be included for the SQL LLM analysis.

In [43]:
%%time
# Business question
question = "How many unique airplane producers are represented in the database?"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE equipment_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock Converse API] to generate SQL for [mysql]

completion = 
```sql
SELECT COUNT(DISTINCT Producer) AS unique_producers
FROM equipment_db.airplanes;
```

This SQL query will return the number of unique airplane producers represented in the database.


LLM SQL Query: 

SELECT COUNT(DISTINCT Producer) AS unique_producers
FROM equipment_db.airplanes;


sql_results = 
(4,)

Calling LLM on [Amazon Bedrock Converse API] to Analyze and interpret the results from SQL query in relation to the original question.


LLM SQL Analysis: 
**Analysis and Interpretation**

The given SQL query results are:

(4,)

This result set contains a single row with two columns: `id` and `name`. The `id` column is an integer, and the `name`

For this example, we expect the `flights` table to be included for the SQL LLM analysis.

In [44]:
%%time
# Business question
question = "Get the total number of flights scheduled for each destination, grouped by arrival date"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE transport_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES equipment_db.airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock Converse API] to generate SQL for [mysql]

completion = 
```sql
SELECT 
    Destination, 
    COUNT(*) as Total_Flights, 
    SUM(CASE WHEN Arrival_date = Arrival_time THEN 1 ELSE 0 END) as Total_Scheduled_Flights
FROM 
    transport_db.flights
GROUP BY 
    Destination, 
    Arrival_date
ORDER BY 
    Total_Scheduled_Flights DESC;
```


LLM SQL Query: 

SELECT 
    Destination, 
    COUNT(*) as Tota

For this example, we expect both the `airplanes` and `flights` tables to be included for the SQL LLM analysis.

In [47]:
%%time
# Business question
question = "Find the airplane IDs and producers for airplanes that have flown to New York"

RunPrompts(llm_selected_db, llm_selected_service, question, collection2, pg_collection2)


ChromaDB - Schema Retrieval: 
<table_schema>
 CREATE TABLE equipment_db.airplanes -- Table name
(
Airplane_id INT(10), -- airplane id
Producer VARCHAR(20), -- name of the producer
Type VARCHAR(10), -- airplane type
PRIMARY KEY (Airplane_id)
)
 
</table_schema>
<table_schema>
 CREATE TABLE transport_db.flights -- Table name
(
Flight_number VARCHAR(10), -- flight id
Arrival_time VARCHAR(20), -- arrival time (YYYY-MM-DDTH:M:S)
Arrival_date VARCHAR(20), -- arrival date (YYYY-MM-DD)
Departure_time VARCHAR(20), -- departure time (YYYY-MM-DDTH:M:S)
Departure_date VARCHAR(20), -- departure date (YYYY-MM-DD)
Destination VARCHAR(20), -- destination
Airplane_id INT(10), -- airplane id
PRIMARY KEY (Flight_number),
FOREIGN KEY (Airplane_id) REFERENCES equipment_db.airplanes(Airplane_id)
)
 
</table_schema>

Using [Amazon Bedrock Converse API] to generate SQL for [mysql]

completion = 
To answer this question, we need to join the `airplanes` table with the `flights` table based on the `Airplane_id` 

## Clean Up Resources

In [49]:
%%time
# Delete resources
if llm_selected_service == 'Amazon SageMaker':
    llm_predictor.delete_model()
    llm_predictor.delete_endpoint()

if embedding_selected_service == 'Amazon SageMaker':
    embedding_predictor.delete_model()
    embedding_predictor.delete_endpoint()

CPU times: user 0 ns, sys: 4 μs, total: 4 μs
Wall time: 5.48 μs


## Conclusion and Next Steps

1. We can observe that ChromaDB with the `Amazon Titan Embedding` model retrieved the expected table schemas for the previous examples based on natural language questions.
2. With the `Amazon Titan Embedding` model, retrieval returns only the tables relevant to each question, instead of the combined schema document for every question.
3. We may also observe that in some cases the LLM does not produce a correct query. This could be due to the 1B variant being used. You can experiment with an LLM that has more parameters, e.g., 3B, 11B, or 90B, and check whether these issues are reduced.
4. We may also observe that running the sequence once does not always yield the expected result, while running the same sequence again does. Consider incorporating a retry mechanism.
5. Consider incorporating a SQL query error correction mechanism as well.
6. Experiment with including additional metadata in the ChromaDB metadata store to define relationships between database objects. This will provide flexibility for more complex scenarios such as stored procedures, triggers, and user-defined functions.
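The retry and SQL error-correction suggestions above can be combined into one loop: execute the generated SQL, and if the database rejects it, pass the error message back to the LLM for another attempt. The sketch below assumes a hypothetical `generate_sql(question, error)` callable standing in for the LLM call and uses the built-in `sqlite3` module in place of MySQL, with made-up rows. It only catches errors the database raises; a query that runs but answers the wrong question would still need a separate check.

```python
import sqlite3
from typing import Callable, List, Optional, Tuple


# Hypothetical retry loop with error feedback. `generate_sql` receives the
# question and the previous error message (None on the first attempt).
def run_with_retries(
    conn: sqlite3.Connection,
    question: str,
    generate_sql: Callable[[str, Optional[str]], str],
    max_attempts: int = 3,
) -> Tuple[str, List[tuple]]:
    error: Optional[str] = None
    for attempt in range(1, max_attempts + 1):
        sql = generate_sql(question, error)
        try:
            return sql, conn.execute(sql).fetchall()
        except sqlite3.Error as exc:
            # Feed the database error into the next generation attempt
            error = f"Attempt {attempt} failed: {exc}"
    raise RuntimeError(f"No valid SQL after {max_attempts} attempts; last error: {error}")


# Demo: the fake generator's first answer references a column that does not
# exist in flights, similar to the "flights per producer" failure above.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE flights (Flight_number TEXT PRIMARY KEY, Airplane_id INTEGER)")
conn.executemany("INSERT INTO flights VALUES (?, ?)", [("F1", 1), ("F2", 1), ("F3", 2)])


def fake_generate_sql(question: str, error: Optional[str]) -> str:
    if error is None:
        return "SELECT Producer, COUNT(*) FROM flights GROUP BY Producer"
    return "SELECT Airplane_id, COUNT(*) FROM flights GROUP BY Airplane_id ORDER BY Airplane_id"


sql, rows = run_with_retries(conn, "flights per airplane", fake_generate_sql)
print(rows)  # [(1, 2), (2, 1)]
```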
