# Text-to-SQL with Llama 4

---

## Introduction

This notebook introduces a versatile approach that leverages Llama 4 on Amazon SageMaker JumpStart, including one-shot example prompt engineering, to convert natural language questions into executable SQL queries. Our approach generates SQL queries capable of joining data from tables across multiple databases, enabling information retrieval from complex database structures. This multi-database capability is crucial in real-world scenarios where data is often distributed across various tables with intricate relationships, and queries need to combine information from multiple sources to provide comprehensive insights.

---

## Llama 4 Models

**There are two Llama 4 models, Scout and Maverick. Today we will focus on the Scout model, which can run on a single-node GPU instance.**

#### Llama 4 Scout (text + image input)
The lighter model in the Llama 4 collection, perfect for applications requiring efficient processing while maintaining high performance. With 17B active parameters (109B total across experts), Scout can run on a single GPU and handles context lengths up to 10M tokens. This model is ideal for:
- Personal information management
- Multilingual knowledge retrieval (supports 12 languages)
- Mobile AI-powered writing assistants
- Customer service applications
- Text summarization and classification tasks

#### Llama 4 Maverick (text + image input)
Meta's most advanced model, designed for enterprise-level applications with 17B active parameters (400B total across experts). With its 128-expert architecture, Maverick excels at:
- General knowledge and reasoning
- Long-form text generation
- Multilingual translation across 12 languages
- Coding and mathematical tasks
- Advanced reasoning and complex problem-solving
- Enterprise applications requiring sophisticated visual reasoning

Both models feature:
- Multimodal capabilities (text + up to 5 images input)
- Support for Arabic, English, French, German, Hindi, Indonesian, Italian, Portuguese, Spanish, Tagalog, Thai, and Vietnamese
- Knowledge cutoff of August 2024
- Optimization for tool-calling and powering agentic systems

For more information, refer to the following link:

1. [Llama 4 Model Cards on GitHub](https://github.com/meta-llama/llama)

---

## Approach to the Text-to-SQL Problem

This notebook covers the following approaches:

### Few-shot text-to-SQL powered by ChromaDB (Schema Retrieval vs Enhanced Schema Retrieval with Sample Questions)

This approach uses ChromaDB, a vector database, to support the few-shot text-to-SQL translation process. ChromaDB stores the database schema information, which includes table names, column names, and their descriptions. When a natural language question is provided, the notebook retrieves the most relevant schema information from ChromaDB and passes it to the model to help it generate the SQL query. ChromaDB can embed the stored documents with its built-in default embedding function or with an embedding model that you provide.

In the enhanced variant, we use the Hugging Face BGE Large EN embedding model, deployed with Amazon SageMaker JumpStart, to generate the embeddings that are stored in and queried against the vector store.

By leveraging ChromaDB, the few-shot text-to-SQL translation process can benefit from efficient schema and sample data retrieval, potentially leading to better performance and generalization across different databases and query types.
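Conceptually, schema retrieval embeds the question, compares it with the embedding of each stored schema document, and keeps the closest match. The following is a simplified sketch of that ranking step, mirroring the `hnsw:space: cosine` setting and `n_results=1` used later in this notebook. The three-dimensional vectors are invented for illustration (BGE Large EN produces 1024-dimensional vectors), and ChromaDB itself uses an approximate HNSW index and reports cosine distance (1 minus similarity) rather than running the exhaustive comparison shown here.

```python
import math
from typing import Dict, List


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vec: List[float], doc_vecs: Dict[str, List[float]], k: int = 1) -> List[str]:
    """Return the ids of the k stored documents most similar to the query."""
    ranked = sorted(
        doc_vecs,
        key=lambda doc_id: cosine_similarity(query_vec, doc_vecs[doc_id]),
        reverse=True,
    )
    return ranked[:k]


# Hypothetical toy embeddings, keyed like the collection ids used later
schema_vectors = {
    "patients": [0.9, 0.1, 0.0],
    "providers": [0.1, 0.9, 0.0],
    "patients-providers": [0.6, 0.6, 0.3],
}
# Hypothetical embedding of "What is the total count of providers?"
question_vector = [0.2, 0.95, 0.05]

print(top_k(question_vector, schema_vectors, k=1))  # ['providers']
```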

---

## Objectives

This notebook provides code snippets to help you implement two different approaches to converting a natural language question into a SQL query. The query is then executed against the database to answer the original question.

We use SageMaker JumpStart to deploy the Llama 4 and embedding models that power this text-to-SQL solution.

---

## Contents

1. [Getting Started](#Getting-Started)
    + [Install Dependencies](#Step-0:-Install-Dependencies)
    + [Select Model Hosting Service](#Step-1:-Select-Hosting-Model-Service)
    + [Get Database and Schema details](#Step-2:-Get-Database-and-Schema-details)
    + [Create helper functions](#Step-3:-Create-helper-functions)
1. [Few-shot text-to-SQL powered by ChromaDB](#Few-shot-text-to-SQL-powered-by-ChromaDB)
    + [Schema Retrieval](#Schema-Retrieval)
    + [Data Preprocessing](#Step-1:-Data-Preprocessing)
    + [Ingest docs into ChromaDB](#Step-2:-Ingest-docs-into-ChromaDB)
    + [Create a Few-Shot Prompt](#Step-3:-Create-a-Few-Shot-Prompt)
    + [Execute Few-Shot Prompts](#Step-4:-Execute-Few-Shot-Prompts)
    + [Conclusion](#Step-5-Conclusion)
1. [Enhanced Schema Retrieval with ChromaDB and an Embedding Model](#Enhanced-Schema-Retrieval-with-an-Embedding-Model)
    + [Ingest docs into ChromaDB](#Step-1:-Ingest-docs-into-ChromaDB)
    + [Execute Few-Shot Prompts](#Step-2:-Execute-Few-Shot-Prompts)
    + [Conclusion](#Step-3-Conclusion)
---

### Tools

+ The AWS SDK for Python, [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html), to call AWS services such as AWS CloudFormation and AWS Secrets Manager. The SageMaker Python SDK is used to deploy and invoke the JumpStart models.

+ [LangChain](https://python.langchain.com/v0.1/docs/get_started/introduction/) is a framework that provides off-the-shelf components to make it easier to build applications with large language models. It offers official libraries for Python and JavaScript/TypeScript. In this notebook, LangChain is used to build a prompt template.

+ [ChromaDB](https://www.trychroma.com/) is a vector database that enables efficient semantic search, storage, and retrieval of unstructured data like text, images, and audio. It's designed to work well with large language models (LLMs) and provides a simple and scalable way to build applications that can search and retrieve relevant information from vast amounts of data.

+ RDS (Relational Database Service) for [MySQL](https://aws.amazon.com/rds/mysql/) is a managed database service provided by Amazon Web Services (AWS). RDS for MySQL simplifies the setup, operation, and scaling of MySQL databases.

---

## Pre-requisites

1. Set up the database and sample data with [this notebook](llama4-chromadb-text2sql-DB-Setup.ipynb) before running this one. This step is mandatory.
2. Use either the `conda_python3`, `conda_pytorch_p310`, or `conda_tensorflow2_p310` kernel.
3. Install the required packages.

**Additionally, add the following IAM policy to your SageMaker execution role:**

```json
{
	"Version": "2012-10-17",
	"Statement": [
		{
			"Effect": "Allow",
			"Action": [
				"cloudformation:DescribeStackResources",
				"cloudformation:DescribeStacks",
				"cloudformation:ListStacks"
			],
			"Resource": "*"
		}
	]
}
```


### SageMaker Deployment

#### Changing instance type
---
Models are supported on the following instance types:

 - Llama 4 Scout: `ml.g6e.48xlarge`, `ml.p5.48xlarge`, `ml.p5en.48xlarge`
 - BGE Large En v1.5: `ml.g5.2xlarge`, `ml.c6i.xlarge`, `ml.g5.4xlarge`, `ml.g5.8xlarge`, `ml.p3.2xlarge`, and `ml.g4dn.2xlarge`

**Note:** By default, the `JumpStartModel` class selects a default instance type available in your Region. To use a different instance type, pass it with the `instance_type` argument of the `JumpStartModel` class, for example:

`my_model = JumpStartModel(model_id=model_id, instance_type="ml.p5.48xlarge")`

---

## Getting Started

### Step 0: Install Dependencies

Here, we will install all the required dependencies to run this notebook.

In [None]:
!pip install boto3==1.35.32 -qU --force --quiet --no-warn-conflicts
!pip install mysql-connector-python==8.4.0 -qU --force --quiet --no-warn-conflicts
!pip install langchain==0.2.5 -qU --force --quiet --no-warn-conflicts
!pip install chromadb==0.5.0 -qU --force --quiet --no-warn-conflicts
!pip install numpy==1.26.4 -qU --force --quiet --no-warn-conflicts
!pip install psycopg2==2.9.9 -qU --force --quiet --no-warn-conflicts


**Note:** *When installing libraries with pip, you may encounter errors or warnings during the installation process. These are generally not critical and can be safely ignored. However, after installing the libraries, restart the kernel or computing environment you are working in so that the newly installed libraries are loaded properly and available for use in your code or workflow.*

<div class='alert alert-block alert-info'><b>NOTE:</b> Restart the kernel with the updated packages that are installed through the dependencies above</div>

In [None]:
# Restart the kernel
import os
os._exit(00)

#### Import the required modules to run the notebook

In [None]:
import boto3
import chromadb
from chromadb.api.types import (
    Documents,
    EmbeddingFunction,
    Embeddings,
)
import json
from langchain_core.prompts import PromptTemplate
import mysql.connector as MySQLdb
import re
from typing import Dict, List, Any
import yaml


### Step 1: Select Hosting Model Service

Here, we deploy the Llama 4 Scout model and the BGE Large EN embedding model as endpoints using Amazon SageMaker JumpStart.

In [None]:
%%time

from sagemaker.jumpstart.model import JumpStartModel

# Deploy Llama 4 Scout on Amazon SageMaker JumpStart
llama4_scout_id = "meta-vlm-llama-4-scout-17b-16e-instruct"
DEFAULT_LLM_MODEL_ID = llama4_scout_id

model = JumpStartModel(model_id=DEFAULT_LLM_MODEL_ID, instance_type="ml.p5.48xlarge")

# accept_eula=True accepts the model's end-user license agreement, which is required to deploy it
llm_predictor = model.deploy(accept_eula=True)

print(f"\nLLM SageMaker Endpoint Name: [{llm_predictor.endpoint_name}].\n")

# Deploy BGE Large EN embedding model on Amazon SageMaker JumpStart:
# Specify the model ID for the HuggingFace BGE Large EN Embedding model
DEFAULT_EMBEDDING_MODEL_ID = "huggingface-sentencesimilarity-bge-large-en"

text_embedding_model = JumpStartModel(model_id=DEFAULT_EMBEDDING_MODEL_ID, instance_type="ml.g5.4xlarge")

embedding_predictor = text_embedding_model.deploy()

print(f"\nEmbedding SageMaker Endpoint Name: [{embedding_predictor.endpoint_name}].\n")

In [None]:
# This notebook example uses Amazon RDS MySQL DB.
llm_selected_db = "mysql"

### Step 2: Get Database and Schema details

Here, we retrieve details of the resources already deployed by the AWS CloudFormation template, which we need to build the application. These include:

+ Secret ARN with RDS for MySQL Database credentials
+ Database Endpoints
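The next cells scan the stack outputs for keys containing `SecretArnMySQL` and `DatabaseEndpointMySQL`. If a key is missing, the corresponding variable is never assigned and the error only surfaces later as a `NameError`. A hypothetical helper like the one below fails earlier with a clearer message. The sample outputs are made up, but follow the `OutputKey`/`OutputValue` shape of the `Outputs` list returned by `describe_stacks`.

```python
from typing import Dict, List


def outputs_to_dict(outputs: List[Dict[str, str]]) -> Dict[str, str]:
    """Flatten CloudFormation stack outputs into {OutputKey: OutputValue}."""
    return {o["OutputKey"]: o["OutputValue"] for o in outputs}


def find_output(outputs: Dict[str, str], fragment: str) -> str:
    """Return the value of the first output whose key contains `fragment`."""
    for key, value in outputs.items():
        if fragment in key:
            return value
    # Fail loudly instead of leaving a variable unassigned
    raise KeyError(f"No stack output key contains {fragment!r}")


# Hypothetical sample shaped like describe_stacks(...)['Stacks'][0]['Outputs']
sample_outputs = [
    {"OutputKey": "SecretArnMySQL",
     "OutputValue": "arn:aws:secretsmanager:us-east-1:111122223333:secret:example"},
    {"OutputKey": "DatabaseEndpointMySQL",
     "OutputValue": "example-db.abc123.us-east-1.rds.amazonaws.com"},
]

outputs = outputs_to_dict(sample_outputs)
print(find_output(outputs, "DatabaseEndpointMySQL"))
```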

In [None]:
stackname = "l4-txt2sql"  # Change this if your CloudFormation stack uses a different name.

In [None]:
cfn = boto3.client('cloudformation')

response = cfn.describe_stack_resources(
    StackName=stackname
)
cfn_outputs = cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Outputs']

# Get rds secret arn and database endpoint from cloudformation outputs
for output in cfn_outputs:
    if 'SecretArnMySQL' in output['OutputKey']:
        mySQL_secret_id = output['OutputValue']

    if 'DatabaseEndpointMySQL' in output['OutputKey']:
        mySQL_db_host = output['OutputValue']


In [None]:
secrets_client = boto3.client('secretsmanager')

# Get MySQL credentials from Secrets Manager
credentials = json.loads(secrets_client.get_secret_value(SecretId=mySQL_secret_id)['SecretString'])

# Get password and username from secrets
mySQL_db_password = credentials['password']
mySQL_db_user = credentials['username']


Establish the database connection (MySQL DB)

In [None]:
mySQL_db_conn = MySQLdb.connect(
    host=mySQL_db_host,
    user=mySQL_db_user,
    password=mySQL_db_password
)

#### Load table schema settings

In [None]:
def load_settings(file_path):
    """
    Reads a YAML file and returns its contents as a Python object.

    Args:
        file_path (str): The path to the YAML file.

    Returns:
        obj: The contents of the YAML file as a Python object, or None if the file
            does not exist or cannot be parsed.
    """
    try:
        with open(file_path, 'r') as file:
            data = yaml.safe_load(file)
        return data
    except FileNotFoundError:
        print(f"Error: The file '{file_path}' does not exist.")
    except yaml.YAMLError as exc:
        print(f"Error: Failed to parse the YAML file '{file_path}': {exc}")

In [None]:
# MySQL Table Setup

# Load table settings - database: healthcare_db | table_name: patients
settings_patients = load_settings('./schemas/patients_ms.yml')
if settings_patients is not None:
    table_patients = settings_patients['table_name']
    table_schema_patients = settings_patients['table_schema']
    db_patients = settings_patients['database']
    print(f"Successfully loaded patients settings: {db_patients}.{table_patients}")
else:
    print("Failed to load patients settings")

# Load table settings - database: insurance_db | table_name: providers
settings_providers = load_settings('./schemas/providers_ms.yml')
if settings_providers is not None:
    table_providers = settings_providers['table_name']
    table_schema_providers = settings_providers['table_schema']
    db_providers = settings_providers['database']
    print(f"Successfully loaded providers settings: {db_providers}.{table_providers}")
else:
    print("Failed to load providers settings")

# Load table settings for combined schema
settings_patients_providers = load_settings('./schemas/patients-providers_ms.yml')
if settings_patients_providers is not None:
    print("Successfully loaded combined settings")
else:
    print("Failed to load combined settings")
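`load_settings` returns `None` when a file is missing or malformed, and a file that lacks a key fails later with a bare `KeyError`. As a sketch, a small validation helper (hypothetical, not part of the original notebook) can check the three keys the cell above reads before they are used. The combined schema file is not covered here, since the notebook treats its `table_schemas` key as optional.

```python
from typing import Any, Dict, Optional

# Keys the cell above reads from each single-table YAML file
REQUIRED_KEYS = ("database", "table_name", "table_schema")


def validate_table_settings(settings: Optional[Dict[str, Any]], source: str) -> Dict[str, Any]:
    """Return `settings` unchanged, or raise ValueError naming what is missing."""
    if settings is None:
        # load_settings() returns None when the file is missing or not valid YAML
        raise ValueError(f"{source}: settings could not be loaded")
    missing = [key for key in REQUIRED_KEYS if key not in settings]
    if missing:
        raise ValueError(f"{source}: missing required keys {missing}")
    return settings


# Hypothetical settings dict, shaped like the parsed contents of ./schemas/providers_ms.yml
example = {
    "database": "insurance_db",
    "table_name": "providers",
    "table_schema": "CREATE TABLE providers (...)",  # placeholder schema text
}
validate_table_settings(example, "providers_ms.yml")
```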

### Step 3: Create helper functions

To improve the usability and readability of the SQL query analysis produced by Llama 4, let's create a suite of helper functions.

##### Chat Completion (Invoke LLM and Return response)

The `sagemaker_chat_completion` function sends chat messages to the Llama 4 SageMaker endpoint. The response from the LLM is extracted and returned as text.

In [None]:
def sagemaker_chat_completion(
    messages: list,
    max_gen_len: int = 512,
    temperature: float = 0.1,
    top_p: float = 0.1
) -> str:
    """
    Generates a chat completion using the OpenAI Chat Completions API format.

    Args:
        messages (list): List of message dictionaries with 'role' and 'content' keys.
        max_gen_len (int, optional): The maximum length of the completion.
        temperature (float, optional): Sampling temperature for the model.
        top_p (float, optional): Top p sampling ratio for the model.

    Returns:
        str: The generated text completion.
    """
    # Create the request payload using the exact OpenAI Chat Completions format
    payload = {
        "messages": messages,
        "max_tokens": max_gen_len,
        "temperature": temperature,
        "top_p": top_p
    }

    # Call the model API to generate the completion
    response = llm_predictor.predict(payload)
    
    # Check for different response formats
    if isinstance(response, dict):
        if 'choices' in response and len(response['choices']) > 0:
            # Standard OpenAI format
            return response['choices'][0]['message']['content'].strip()
        elif 'generated_text' in response:
            # Alternative text-generation response format
            return response['generated_text'].strip()
    
    # If we can't extract a valid response
    print(f"Unexpected response format: {response}")
    return str(response)

The `sagemaker_chat_completion` function sends the request in the OpenAI Chat Completions message format and accepts two response shapes: an OpenAI-style `choices` list or a `generated_text` field. Any other response is printed and returned as a string, so check the output before relying on it.

##### Query execution and LLM calling

The `execute_query` function runs a SQL query on the given database connection and formats the returned rows as a newline-separated string for further processing or display.

In [None]:
def execute_query(query: str, db_conn) -> str:
    """Execute an SQL query on the database connection and return the results as a string.

    Args:
        query (str): SQL query to execute
        db_conn (Connection object): Connection object to the database where the query needs to be executed.

    Returns:
        str: A formatted string containing the SQL results.
    """
    # Get a cursor from the database connection
    mycursor = db_conn.cursor()
    try:
        # Execute the SQL query
        mycursor.execute(query)

        # Fetch all result rows
        result_rows = mycursor.fetchall()
    finally:
        # Release the cursor even if the query fails
        mycursor.close()

    # Convert result to string with newline between rows
    output_text = '\n'.join([str(x) for x in result_rows])
    return output_text
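To see exactly what `execute_query` hands to the LLM, the following sketch reproduces its formatting logic against an in-memory SQLite database (Python's standard `sqlite3` module), so it runs without the RDS instance. The table and rows are hypothetical. `mysql.connector` also returns rows as tuples by default, so the output has the same shape, although some MySQL types such as `DECIMAL` will render differently (for example as `Decimal('1.00')`).

```python
import sqlite3


def execute_query_sketch(query: str, db_conn) -> str:
    """Same formatting as execute_query: one str(row) per line."""
    cursor = db_conn.cursor()
    try:
        cursor.execute(query)
        rows = cursor.fetchall()
    finally:
        cursor.close()
    return "\n".join(str(row) for row in rows)


# Hypothetical in-memory table standing in for insurance_db.providers
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE providers (provider_id INTEGER, name TEXT)")
conn.executemany("INSERT INTO providers VALUES (?, ?)", [(1, "Provider A"), (2, "Provider B")])

print(execute_query_sketch("SELECT COUNT(*) FROM providers", conn))  # (2,)
print(execute_query_sketch("SELECT provider_id, name FROM providers ORDER BY provider_id", conn))
```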

The function `get_llm_sql_analysis` generates and executes a SQL query for a given question, and returns an analysis based on the SQL query results.
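The SQL-extraction step in `get_llm_sql_analysis` (and in `RunPrompts`, defined later) looks for `<sql>` tags first and falls back to a fenced `sql` code block. The standalone sketch below isolates that logic so it can be tested without calling the model. It uses non-greedy patterns, so if the model returns several `<sql>` blocks only the first is captured; a greedy `(.*)` would instead span from the first opening tag to the last closing tag.

```python
import re
from typing import Optional

SQL_PATTERNS = (
    r"<sql>(.*?)</sql>",        # preferred: the tags the system prompt asks for
    r"```sql\s*(.*?)\s*```",    # fallback: a Markdown code fence
)


def extract_sql(completion: str) -> Optional[str]:
    """Return the first SQL statement found in an LLM completion, or None."""
    for pattern in SQL_PATTERNS:
        match = re.search(pattern, completion, re.DOTALL)
        if match:
            return match.group(1).strip()
    return None


reply = "Here is the query:\n<sql>\nSELECT COUNT(*) FROM insurance_db.providers;\n</sql>"
print(extract_sql(reply))  # SELECT COUNT(*) FROM insurance_db.providers;
```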

In [None]:
def get_llm_sql_analysis(question: str, sql_sys_prompt: str, qna_sys_prompt: str, DatabaseType: str) -> str:
    """
    Generates an SQL query based on the given question, executes it, and returns an analysis.

    Args:
        question (str): The input question for which an SQL query needs to be generated.
        sql_sys_prompt (str): The prompt to be used for generating the SQL query.
        qna_sys_prompt (str): The prompt to be used for analyzing the SQL query results.
        DatabaseType (str): The type of database to query (mysql/postgresql).

    Returns:
        str: The analysis of the SQL query results.
    """
    try:
        # *****************************************************************************************
        # 1. Generate SQL Query
        # *****************************************************************************************
        print(f"\nUsing SageMaker to generate SQL for [{DatabaseType}]\n")
        
        # Create messages for SQL generation using OpenAI format
        sql_messages = [
            {"role": "system", "content": sql_sys_prompt},
            {"role": "user", "content": question}
        ]
        
        completion = sagemaker_chat_completion(messages=sql_messages)
        print(f"completion = \n{completion}\n")
        
        # *****************************************************************************************
        # 2. Extract SQL and Execute
        # *****************************************************************************************
        # Extract SQL query using regex patterns (non-greedy, so only the first block is captured)
        pattern = r"<sql>(.*?)</sql>"
        sr = re.search(pattern, completion, re.DOTALL)

        if sr is None:
            pattern = r"```sql(.*?)```"
            sr = re.search(pattern, completion, re.DOTALL)

        if sr is None:
            raise ValueError("Could not extract SQL query from completion")

        llm_sql_query = sr.group(1).strip()
        print(f"\nLLM SQL Query: \n{llm_sql_query}")

        # Route to appropriate database connection
        match DatabaseType:
            case "mysql":
                db_conn = mySQL_db_conn
            case "postgresql":
                db_conn = pg_db_conn
            case _:
                raise ValueError(f"Unsupported database type: {DatabaseType}")

        # Execute SQL query
        sql_results = execute_query(llm_sql_query, db_conn)
        print(f"\nsql_results = \n{sql_results}")

        # *****************************************************************************************
        # 3. Evaluate the response
        # *****************************************************************************************
        print(f"\nCalling LLM on SageMaker to analyze the results\n")
        
        # Create messages for analysis using OpenAI format
        analysis_prompt = qna_sys_prompt.format(query_results=sql_results, question=question)
        analysis_messages = [
            {"role": "system", "content": "You are a helpful assistant that analyzes SQL query results."},
            {"role": "user", "content": analysis_prompt}
        ]
        
        llm_sql_analysis = sagemaker_chat_completion(messages=analysis_messages)
        print(f"\nLLM SQL Analysis: \n{llm_sql_analysis}")

        return llm_sql_analysis
    except Exception as e:
        print(f"\nException Encountered: \n{e}\n")
        return str(e)

##### Embedding Functions

The class `AmazonSageMakerEmbeddingFunction` implements a ChromaDB embedding function backed by the `Amazon SageMaker BGE Large` model deployed from JumpStart. This class can be extended to add support for other embedding models available on Amazon SageMaker.

In [None]:
class AmazonSageMakerEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
        self,
        **kwargs: Any,
    ):
        """Initialize AmazonSageMakerEmbeddingFunction.

        Args:
            **kwargs: Additional arguments (currently unused; the endpoint is taken
                from the global `embedding_predictor`).

        Example:
            >>> sagemaker_ef = AmazonSageMakerEmbeddingFunction()
            >>> text_inputs = ["Hello, world!", "How are you?"]
            >>> embeddings = sagemaker_ef(text_inputs)
        """


    def __call__(self, input: Documents) -> Embeddings:
        accept = "application/json"
        content_type = "application/json"

        embeddings = []
        for text in input:
            input_body = {"text_inputs": text, "mode": "embedding"}
            body = json.dumps(input_body).encode('utf-8')
            response = embedding_predictor.predict(
                body,
                {
                    "ContentType": content_type,
                    "Accept": accept,
                }
            )
            embedding = response.get("embedding")
            embeddings.append(embedding)
        return embeddings
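When iterating on the ChromaDB ingestion code, it can help to have an embedding function that does not call an endpoint. The class below is a hypothetical offline stand-in: it hashes lowercase tokens into a fixed number of buckets and L2-normalizes the result. It captures word overlap only, not meaning, so it is useful for testing plumbing, not retrieval quality. It keeps the same `__call__(self, input)` signature as `AmazonSageMakerEmbeddingFunction`; we assume ChromaDB accepts it the same way, though subclassing `EmbeddingFunction` as the class above does is the safer choice.

```python
import hashlib
import math
from typing import List


class HashingEmbeddingFunction:
    """Hypothetical offline stand-in for an embedding endpoint (testing only)."""

    def __init__(self, dim: int = 64):
        self.dim = dim

    def __call__(self, input: List[str]) -> List[List[float]]:
        embeddings = []
        for text in input:
            vec = [0.0] * self.dim
            for token in text.lower().split():
                # Stable bucket per token (built-in hash() is randomized per process)
                digest = hashlib.sha256(token.encode("utf-8")).hexdigest()
                vec[int(digest, 16) % self.dim] += 1.0
            # L2-normalize; the `or 1.0` avoids dividing by zero for empty text
            norm = math.sqrt(sum(v * v for v in vec)) or 1.0
            embeddings.append([v / norm for v in vec])
        return embeddings


ef = HashingEmbeddingFunction(dim=64)
vectors = ef(["total count of providers", "total count of providers", "patients"])
print(len(vectors), len(vectors[0]))  # 3 64
```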

## Few-shot text-to-SQL powered by ChromaDB

We use ChromaDB to retrieve the relevant table schemas for each question and pass them to the LLM in a few-shot prompt, aiming for better performance and generalization across different databases and query types.

### Schema Retrieval

In this approach, we will store only the table schemas in ChromaDB.

### Step 1: Data Preprocessing

The first step is to preprocess the data and create a document that will be ingested into ChromaDB. The final doc clearly separates the table schemas by using XML tags such as `<table_schema></table_schema>`.
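The tagging scheme is easy to check in isolation. The sketch below builds a document in the same `<table_schemas>`/`<table_schema>` layout as the cell that follows and recovers the schema text with the same regular expression `RunPrompts` applies after retrieval. The `CREATE TABLE` strings are placeholders, not the real schemas from the YAML files.

```python
import re
from typing import List


def build_schema_doc(table_schemas: List[str]) -> str:
    """Wrap each table schema in <table_schema> tags inside one <table_schemas> block."""
    parts = ["<table_schemas>"]
    for schema in table_schemas:
        parts.append(f"<table_schema>\n {schema} \n</table_schema>")
    parts.append("</table_schemas>")
    return "\n".join(parts)


def extract_table_schemas(doc: str) -> str:
    """Recover the text between <table_schemas> tags, as RunPrompts does after retrieval."""
    match = re.search(r"<table_schemas>(.*)</table_schemas>", doc, re.DOTALL)
    if match is None:
        raise ValueError("Document has no <table_schemas> block")
    return match.group(1)


# Placeholder schemas for illustration
doc = build_schema_doc(["CREATE TABLE patients (...)", "CREATE TABLE providers (...)"])
print(extract_table_schemas(doc).strip())
```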

In [None]:
# For MySQL Healthcare Database
# The doc includes a structure format for clearly identifying the table schemas
doc1 = "<table_schemas>\n"
doc1 += f"<table_schema>\n {table_schema_patients} \n</table_schema>\n".strip()
doc1 += "\n</table_schemas>"

# The doc includes a structure format for clearly identifying the table schemas
doc2 = "<table_schemas>\n"
doc2 += f"<table_schema>\n {table_schema_providers} \n</table_schema>\n".strip()
doc2 += "\n</table_schemas>"

# The doc includes a structure format for clearly identifying the combined table schemas
doc3 = "<table_schemas>\n"
if settings_patients_providers is not None and 'table_schemas' in settings_patients_providers:
    for table_schema in settings_patients_providers['table_schemas']:
        doc3 += f"<table_schema>\n {table_schema} \n</table_schema>\n"
else:
    # If the combined schema doesn't have a table_schemas list, include both individual schemas
    doc3 += f"<table_schema>\n {table_schema_patients} \n</table_schema>\n"
    doc3 += f"<table_schema>\n {table_schema_providers} \n</table_schema>\n"
doc3 += "\n</table_schemas>".strip()


### Step 2: Ingest docs into ChromaDB

After the data is preprocessed, the next step is to ingest all `docs` into ChromaDB.

In [None]:
chroma_client = None

# Setup Chroma in-memory, for easy prototyping.
chroma_client = chromadb.Client()


In [None]:
# Delete collection if exists
try:
    chroma_client.get_collection(name="healthcare-table-schemas-default-embedding")
except ValueError:
    # Collection does not exist
    pass
else:
    chroma_client.delete_collection(name="healthcare-table-schemas-default-embedding")


In [None]:
# For MySQL Healthcare Database
# Create collection using ChromaDB's internal embedding function
collection1 = chroma_client.get_or_create_collection(name="healthcare-table-schemas-default-embedding", 
                                                     metadata={"hnsw:space": "cosine"})

# Add docs to the collection.
collection1.add(
    documents=[
        doc1,
        doc2,
        doc3
    ],
    metadatas=[
        {"source": "mysql", "database": db_patients, "table_name": table_patients},
        {"source": "mysql", "database": db_providers, "table_name": table_providers},
        {"source": "mysql", "database": f"{db_patients}-{db_providers}", "table_name": f"{table_patients}-{table_providers}" }
    ],
    ids=[table_patients, table_providers, f"{table_patients}-{table_providers}"], # unique for each doc
)

pg_collection1 = None

In [None]:
def RunPrompts(DatabaseType: str, question: str, mySQLCol, pgCol):
    """
    Helper function that does the following:
    - Picks up the ChromaDB collection (passed as input parameters) based on the DatabaseType
    - Retrieves relevant table schemas from ChromaDB based on the Business Question.
    - Invoke SageMaker LLM to return the SQL query for the table schemas retrieved.
    - Run the query against the database.
    - Invoke SageMaker LLM to analyze results of the query execution against the Business Question.

    Args:
        DatabaseType (str): mysql / postgresql
        question (str): User's business question
        mySQLCol (object): ChromadB collection for mySQL Database
        pgCol (object): ChromadB collection for PostgreSQL Database

    Returns:
        str: The analysis of the SQL query results provided by the language model.
    """
    # Route the query according to the database passed
    if DatabaseType == "mysql":
        # For MySQL DB
        collection_to_use = mySQLCol
        db_conn = mySQL_db_conn
    elif DatabaseType == "postgresql":
        # For PostgreSQL DB
        collection_to_use = pgCol
        db_conn = pg_db_conn
    else:
        raise ValueError(f"Unsupported database type: {DatabaseType}")
    
    # Retrieve the single most similar schema document
    docs = collection_to_use.query(
        query_texts=[question],
        n_results=1
    )

    pattern = r"<table_schemas>(.*)</table_schemas>"
    table_schemas = re.search(pattern, docs["documents"][0][0], re.DOTALL).group(1)
    print(f"ChromaDB - Schema Retrieval: \n{table_schemas.strip()}")
    
    # Format the SQL system prompt
    SQL_SYS_PROMPT = tmp_sql_sys_prompt.format(
        question=question,
        table_schemas=table_schemas,
        dbtype=DatabaseType
    )

    # Get the SQL query from the LLM
    print(f"Using SageMaker to generate SQL for [{DatabaseType}]")
    llm_messages = [
        {"role": "system", "content": SQL_SYS_PROMPT},
        {"role": "user", "content": question}
    ]
    llm_output = sagemaker_chat_completion(llm_messages)
    print("\ncompletion = ")
    print(llm_output)
    
    # Extract SQL query from LLM output
    sql_pattern = r"<sql>(.*?)</sql>"
    sql_matches = re.findall(sql_pattern, llm_output, re.DOTALL)
    
    if not sql_matches:
        # Try alternative formatting with code blocks
        sql_pattern = r"```sql\s*(.*?)\s*```"
        sql_matches = re.findall(sql_pattern, llm_output, re.DOTALL)
        
    if not sql_matches:
        print("⚠️ Warning: No SQL query found in LLM response. Using full response.")
        llm_sql_query = llm_output
    else:
        llm_sql_query = sql_matches[0].strip()
    
    print("\nLLM SQL Query: ")
    print(llm_sql_query)
    
    # Validate the SQL query
    filter_keywords = ["specific", "medical", "certain", "particular", "only", "with"]
    classification_keywords = ["classify", "categorize", "group", "type", "category"]
    
    # Check if the query might be missing filtering
    if ("WHERE" not in llm_sql_query.upper() and 
        any(keyword in question.lower() for keyword in filter_keywords) and
        not any(keyword in question.lower() for keyword in ["all", "every", "total"])):
        print("⚠️ Warning: The query may be missing filtering criteria implied in the question.")
    
    # Check if multiple statements but only some will execute
    if llm_sql_query.count(";") > 1:
        print("⚠️ Note: Multiple SQL statements detected. Ensure all necessary statements are executed.")
    
    # Execute the SQL query
    try:
        sql_results = execute_query(llm_sql_query, db_conn)
        print("\nsql_results = ")
        print(sql_results)
    except Exception as e:
        print(f"⚠️ Error executing SQL: {str(e)}")
        sql_results = f"Error: {str(e)}"
    
    # Analyze the results
    analysis_prompt = QNA_SYS_PROMPT.format(
        query_results=sql_results,
        question=question,
        sql_query=llm_sql_query  # Include the SQL query for context
    )
    
    analysis_messages = [
        {"role": "system", "content": "You are a data analyst providing insights from SQL query results."},
        {"role": "user", "content": analysis_prompt}
    ]
    
    analysis = sagemaker_chat_completion(analysis_messages)
    print("\nAnalysis: ")
    print(analysis)
    
    return analysis

### Step 3: Create a Few-Shot Prompt

Here, we design our prompt template that will account for our question and answer, and formatted correctly for use with Llama 4 models.

First, we create the SQL-generation prompt, which has two inputs:

1. `table_schemas`. This is a description of the structure of the database table(s), including the name of each table, the names of its columns, and the data type of each column. It is inserted into the system prompt and helps Llama 4 understand the organization and contents of the tables.

2. `question`. This is the specific request or information that the user wants to obtain from the tables. It is sent as the user message.

By providing both the table schemas and the user's question, we give the Llama 4 model a complete understanding of the table structure and the user's desired output.

Because we use a few-shot approach with the tables retrieved from ChromaDB, the system prompt contains a single `{table_schemas}` placeholder that can hold any number of retrieved table schemas.

In [None]:
instructions = [
    {
        "role": "system",
        "content": 
        """You are a {dbtype} query expert whose output is a valid SQL query.

Only use the tables described by the following schemas:
<table_schemas>
{table_schemas}
</table_schemas>

Important guidelines:
1. Always combine the database name and table name to build your queries.
2. Only make assumptions based on the schema provided - do not invent columns that aren't in the schema.
3. If the query requires filtering (e.g., "medical coverage"), ensure you include appropriate WHERE clauses.
4. If the question cannot be fully answered with the provided schema, explicitly state what additional information would be needed.
5. For multi-part questions, ensure you address all parts in your queries.

Please construct a valid SQL statement to answer the following question, return only the {dbtype} query in between <sql></sql>.
"""
    },
    {
        "role": "user",
        "content": "{question}"
    }
]

# Extract prompts directly
tmp_sql_sys_prompt = instructions[0]['content']
sysPt = instructions[0]['content']
userPt = instructions[1]['content']
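The message structure that `RunPrompts` sends can be sketched as follows, using an abbreviated, hypothetical copy of the system prompt above: the retrieved schemas and the SQL dialect fill the system prompt, and the question goes in the user turn. One design point worth knowing: `str.format` substitutes each value once and does not re-parse it, so braces inside a retrieved schema are inserted verbatim; only literal braces written in the template itself would need to be doubled (`{{`, `}}`).

```python
from typing import Dict, List

# Abbreviated, hypothetical version of tmp_sql_sys_prompt from the cell above
SQL_SYS_TEMPLATE = (
    "You are a {dbtype} query expert whose output is a valid SQL query.\n"
    "<table_schemas>\n{table_schemas}\n</table_schemas>\n"
    "Return only the {dbtype} query in between <sql></sql>."
)


def build_sql_messages(question: str, table_schemas: str, dbtype: str) -> List[Dict[str, str]]:
    """Fill the system prompt template and pair it with the user's question."""
    system_prompt = SQL_SYS_TEMPLATE.format(table_schemas=table_schemas, dbtype=dbtype)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},  # the question travels as the user turn
    ]


messages = build_sql_messages(
    question="What is the total count of providers?",
    table_schemas="<table_schema>\n CREATE TABLE providers (...) \n</table_schema>",  # placeholder
    dbtype="mysql",
)
print(messages[0]["content"])
```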


Next, we create the analysis prompt, sent as the user message, which contains two parts:

1. `query_results`. These are the results of executing the SQL query generated with `tmp_sql_sys_prompt`. This is the raw data that the Llama 4 model will use to generate its analysis.

2. `question`. This specifies the type of analysis or insight that the user wants the Llama 4 model to provide based on the SQL query results.

By combining the SQL query results and the user's question into a single prompt, we give the Llama 4 model all the information it needs to understand the context and provide an analysis tailored to the user's request.

In [None]:
instructions = [
    {
        "role": "user",
        "content": """Given the following SQL query results:
{query_results}

And the original question:
{question}

Please provide an analysis and interpretation of the results to answer the original question.
"""
    }
]

# Extract the prompt directly
QNA_SYS_PROMPT = instructions[0]['content']
sysPtI = ""  # No system prompt in this instructions list
userPtI = instructions[0]['content']
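Note that `RunPrompts` calls `QNA_SYS_PROMPT.format(...)` with a `sql_query` argument, but the template above has no `{sql_query}` placeholder. `str.format` silently ignores unused keyword arguments, so the generated SQL never reaches the analysis prompt. The sketch below demonstrates this with an abbreviated copy of the template and a hypothetical variant that does include the query; the `(42,)` result is made up.

```python
# Abbreviated copy of QNA_SYS_PROMPT from the cell above
QNA_TEMPLATE = (
    "Given the following SQL query results:\n{query_results}\n\n"
    "And the original question:\n{question}\n"
)

# Hypothetical variant that also shows the model the SQL that produced the results
QNA_TEMPLATE_WITH_SQL = (
    "The following SQL query:\n{sql_query}\n\n"
    "returned these results:\n{query_results}\n\n"
    "Original question:\n{question}\n"
)

kwargs = {
    "query_results": "(42,)",  # made-up result for illustration
    "question": "What is the total count of patients?",
    "sql_query": "SELECT COUNT(*) FROM healthcare_db.patients;",
}

# Unused keyword arguments are ignored, so sql_query is silently dropped here...
print(QNA_TEMPLATE.format(**kwargs))
# ...and only appears when the template has a {sql_query} placeholder.
print(QNA_TEMPLATE_WITH_SQL.format(**kwargs))
```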

### Step 4: Execute Few-Shot Prompts

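In these examples, we expect ChromaDB to retrieve the `providers` schema for the first question, the `patients` schema for the second, and the combined patients-providers schema for the third.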

In [None]:
%%time
# Business question
question = "What is the total count of providers?"

# Retrieve the schema, generate and run the SQL query, then analyze the results
results = RunPrompts(
    DatabaseType=llm_selected_db, 
    question=question, 
    mySQLCol=collection1, 
    pgCol=pg_collection1
)
results
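
To check an expectation like "the providers table is used" against the generated SQL, a rough heuristic is to list the identifiers that follow `FROM` and `JOIN`. The helpers below are a hypothetical, regex-based sketch: they handle `db.table` prefixes and backtick or double-quote identifiers, but not comma-separated `FROM` lists, and they are no substitute for a real SQL parser. The sample query is made up.

```python
import re

# Identifier after FROM or JOIN: word characters, dots (db.table), and
# MySQL backticks or ANSI double quotes. A "(" (subquery) does not match.
TABLE_REF = re.compile(r'\b(?:FROM|JOIN)\s+([`"\w.]+)', re.IGNORECASE)


def referenced_tables(sql: str) -> set:
    """Return lower-cased table names that follow FROM/JOIN in a SQL string."""
    return {
        raw.replace("`", "").replace('"', "").lower()
        for raw in TABLE_REF.findall(sql)
    }


def uses_table(sql: str, table: str) -> bool:
    """True if `table` is referenced, with or without a database prefix."""
    return any(name.split(".")[-1] == table.lower() for name in referenced_tables(sql))


sql = (
    "SELECT pr.provider_name, COUNT(*) FROM `providers_db`.`providers` pr "
    "JOIN patients_db.patients pt ON pt.Insurance_id = pr.Insurance_id "
    "GROUP BY pr.provider_name"
)
print(sorted(referenced_tables(sql)))  # ['patients_db.patients', 'providers_db.providers']
print(uses_table(sql, "providers"))    # True
```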

In [None]:
%%time
# Business question
question = "What is the total count of patients?"

# Call the updated RunPrompts function without ServiceType parameter
results = RunPrompts(
    DatabaseType=llm_selected_db, 
    question=question, 
    mySQLCol=collection1, 
    pgCol=pg_collection1
)
results

In [None]:
%%time
# Business question
question = "What is the total count of patients per provider?"

# Call the updated RunPrompts function without ServiceType parameter
results = RunPrompts(
    DatabaseType=llm_selected_db, 
    question=question, 
    mySQLCol=collection1, 
    pgCol=pg_collection1
)
results
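
The patients-per-provider question needs a join across the two databases on the `Insurance_id` foreign key. As a minimal, self-contained sketch of that query shape, the example below uses SQLite's `ATTACH DATABASE` to emulate two separate databases in memory. The database names, table names, columns, and rows are hypothetical stand-ins for `db_patients` and `db_providers`; in MySQL, qualifying a table as `database.table` works the same way for databases on the same server.

```python
import sqlite3

# Hypothetical stand-ins for the two databases: each ATTACHed in-memory
# database plays the role of db_patients / db_providers.
conn = sqlite3.connect(":memory:")
conn.execute("ATTACH DATABASE ':memory:' AS patients_db")
conn.execute("ATTACH DATABASE ':memory:' AS providers_db")

conn.execute(
    "CREATE TABLE providers_db.providers (Insurance_id INTEGER PRIMARY KEY, provider_name TEXT)"
)
conn.execute(
    "CREATE TABLE patients_db.patients (patient_id INTEGER PRIMARY KEY, name TEXT, Insurance_id INTEGER)"
)
conn.executemany(
    "INSERT INTO providers_db.providers VALUES (?, ?)",
    [(1, "Provider A"), (2, "Provider B"), (3, "Provider C")],
)
conn.executemany(
    "INSERT INTO patients_db.patients VALUES (?, ?, ?)",
    [(10, "P1", 1), (11, "P2", 1), (12, "P3", 2)],
)

# Join on the Insurance_id foreign key. LEFT JOIN keeps providers with no
# patients; COUNT(pt.patient_id) counts only matched patient rows.
query = """
SELECT pr.provider_name, COUNT(pt.patient_id) AS patient_count
FROM providers_db.providers AS pr
LEFT JOIN patients_db.patients AS pt
  ON pt.Insurance_id = pr.Insurance_id
GROUP BY pr.provider_name
ORDER BY pr.provider_name
"""
rows = conn.execute(query).fetchall()
print(rows)  # [('Provider A', 2), ('Provider B', 1), ('Provider C', 0)]
```

The `LEFT JOIN` is a design choice: it reports Provider C with a count of 0, whereas an inner join would drop providers without patients, which may or may not match what the question intends.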

## Enhanced Schema Retrieval with an Embedding Model

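In this approach, we embed the table schemas with an embedding model hosted on AWS and store them in a separate ChromaDB collection, which is then used to retrieve the relevant schemas for each question.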

### Step 1: Ingest docs into ChromaDB

After the data is preprocessed, the next step is to ingest all `docs` into ChromaDB using your selected embedding model (`Amazon SageMaker BGE Large English` or `Amazon Titan Embedding V2`).

In [None]:
# Create embedding function with AWS
aws_ef = AmazonSageMakerEmbeddingFunction()

In [None]:
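# Delete the collection if it already exists so ingestion starts from a clean state
try:
    chroma_client.delete_collection(name="table-schemas-aws-embedding-model")
except Exception:
    # The collection does not exist (the exception type varies across ChromaDB versions)
    pass

# Recreate the collection so the schemas are embedded with the AWS embedding function
collection2 = chroma_client.create_collection(
    name="table-schemas-aws-embedding-model",
    embedding_function=aws_ef,
)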


In [None]:
try:
    # Create a more comprehensive combined schema with clear table relationships
    combined_schema = f"""
    /* Database: {db_patients} */
    {table_schema_patients}
    
    /* Database: {db_providers} */
    {table_schema_providers}
    
    /* Relationships */
    The {db_patients}.{table_patients} table has a foreign key Insurance_id that references {db_providers}.{table_providers}(Insurance_id).
    This relationship indicates which provider covers which patient.
    """
    
    collection2.add(
        documents=[
            table_schema_patients,
            table_schema_providers,
            combined_schema
        ],
        metadatas=[
            {"source": "mysql", "database": db_patients, "table_name": table_patients},
            {"source": "mysql", "database": db_providers, "table_name": table_providers},
            {"source": "mysql", "database": f"{db_patients}-{db_providers}", "table_name": f"{table_patients}-{table_providers}" }
        ],
        ids=[table_patients, table_providers, f"{table_patients}-{table_providers}"], # unique for each doc
    )
    print("Successfully added documents to collection")
except Exception as e:
    print(f"Error adding documents to collection: {str(e)}")
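
Storing the combined schema document alongside the per-table documents is meant to let a cross-table question retrieve the relationship description. The toy sketch below illustrates that retrieval step with a bag-of-words "embedding" and cosine similarity standing in for the AWS embedding model. The documents are hypothetical and far shorter than the real schemas, and a real embedding model would also relate words such as *provider* and *providers*, which this toy does not.

```python
import math
import re
from collections import Counter


def embed(text: str) -> Counter:
    """Toy 'embedding': lower-cased word counts (a stand-in for a real model)."""
    return Counter(re.findall(r"[a-z_]+", text.lower()))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def top_k(question: str, docs: dict, k: int = 1) -> list:
    """Return the ids of the k documents most similar to the question."""
    q = embed(question)
    ranked = sorted(docs, key=lambda doc_id: cosine(q, embed(docs[doc_id])), reverse=True)
    return ranked[:k]


# Hypothetical, heavily shortened schema documents keyed like the ids in collection2
docs = {
    "patients": "patients table columns patient_id name Insurance_id",
    "providers": "providers table columns Insurance_id provider_name coverage",
    "patients-providers": "patients table providers table Insurance_id relationship provider covers patient",
}

print(top_k("count of patients per provider", docs))  # ['patients-providers']
print(top_k("total count of providers", docs))        # ['providers']
```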

### Step 2: Execute Few-Shot Prompts

In these examples, we expect the providers table to be retrieved for the first question, and both the patients and providers tables (related through `Insurance_id`) for the second.

In [None]:
%%time
# Business question
question = "What is the total count of providers?"

# Retrieve schemas from the collection embedded with the AWS embedding model
results = RunPrompts(
    DatabaseType=llm_selected_db,
    question=question,
    mySQLCol=collection2,
    pgCol=pg_collection1
)
results

In [None]:
%%time
# Business question
question = "What is the total count of patients per provider?"

# Retrieve schemas from the collection embedded with the AWS embedding model
results = RunPrompts(
    DatabaseType=llm_selected_db,
    question=question,
    mySQLCol=collection2,
    pgCol=pg_collection1
)
results

For this example, we expect the providers table to be included for the SQL LLM analysis.

In [None]:
%%time
# Business question
question = "How many unique providers are represented in the database?"

# Retrieve schemas from the collection embedded with the AWS embedding model
results = RunPrompts(
    DatabaseType=llm_selected_db,
    question=question,
    mySQLCol=collection2,
    pgCol=pg_collection1
)
results

For this example, we expect the providers table to be included for the SQL LLM analysis, with a WHERE clause that filters on medical coverage.

In [None]:
%%time
# Business question
question = "How many providers provide medical coverage?"

# Retrieve schemas from the collection embedded with the AWS embedding model
results = RunPrompts(
    DatabaseType=llm_selected_db,
    question=question,
    mySQLCol=collection2,
    pgCol=pg_collection1
)
results
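
The last two questions hinge on `COUNT(DISTINCT ...)` and a `WHERE` filter. The sketch below runs both query shapes against a hypothetical in-memory SQLite `providers` table in which a provider can appear once per coverage type; the column names and rows are assumptions for illustration, not the notebook's schema. On this data a plain `COUNT(*)` would return 4, because Provider A appears twice.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical providers table: one row per (provider, coverage type) plan
conn.execute(
    "CREATE TABLE providers (Insurance_id INTEGER PRIMARY KEY, provider_name TEXT, coverage_type TEXT)"
)
conn.executemany(
    "INSERT INTO providers VALUES (?, ?, ?)",
    [
        (1, "Provider A", "Medical"),
        (2, "Provider A", "Dental"),
        (3, "Provider B", "Medical"),
        (4, "Provider C", "Vision"),
    ],
)

# "How many unique providers?": COUNT(DISTINCT ...) collapses repeated names
unique_providers = conn.execute(
    "SELECT COUNT(DISTINCT provider_name) FROM providers"
).fetchone()[0]

# "How many providers provide medical coverage?": the WHERE clause applies the filter
medical_providers = conn.execute(
    "SELECT COUNT(DISTINCT provider_name) FROM providers WHERE coverage_type = ?",
    ("Medical",),
).fetchone()[0]

print(unique_providers, medical_providers)  # 3 2
```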


## Clean Up Resources

In [None]:
%%time
# Delete resources
try:
    llm_predictor.delete_model()
    llm_predictor.delete_endpoint()
    print("LLM resources deleted successfully")
except Exception as e:
    print(f"Error deleting LLM resources: {e}")

try:
    embedding_predictor.delete_model()
    embedding_predictor.delete_endpoint()
    print("Embedding resources deleted successfully")
except Exception as e:
    print(f"Error deleting embedding resources: {e}")

## Distributors
- AWS
