# Simplify natural language query using Anthropic Claude on Amazon Bedrock
---

In this notebook, we will explore how to use the multi-modal capabilities of the **Anthropic Claude** foundation models, which are available on **Amazon Bedrock**.

By supplying the entity relationship diagram (ERD) as an image and the user's question as a text prompt, the LLM can generate a SQL statement for querying the data lake with the **Amazon Athena** service.



## Set up
---

### Upgrade boto3

Ensure our environment is using the most up-to-date `boto3` library.

In [None]:
%pip install --upgrade boto3 --quiet --root-user-action=ignore

### Define logger

In [None]:
import logging

# Configure the root logger once for the whole notebook. Each record is
# prefixed with timestamp, process id, source file/line, and level.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
)

# Logger shared by all cells below.
logger = logging.getLogger(__name__)

## Dataset
---
The dataset is in CSV format and is available in the `../data` folder; please note that this dataset is synthetic. The entity relationship diagram (ERD) is available in the `../schema_img` folder and will be needed by our foundation model in a later step.

<img src='../schema_img/schema.png' alt='dataset schema'>

In [None]:
import os

# Folder that holds the synthetic CSV dataset.
data_dir = '../data'

# Keep only files whose name actually ends with '.csv'; a plain substring
# test would also pick up names such as 'table.csv.bak'. Sorting makes the
# processing order reproducible, since os.listdir returns arbitrary order.
file_list = sorted(
    filename for filename in os.listdir(data_dir) if filename.endswith('.csv')
)
logger.info(file_list)

### Upload dataset to Amazon S3

In [None]:
import boto3
from botocore.exceptions import ClientError
import botocore
import sagemaker
import json
from typing import Optional, Dict, Tuple, List

# One boto3 session is shared by every AWS client created in this notebook.
boto_session = boto3.session.Session()
region_name = boto_session.region_name

# Build the S3 client from the shared session, consistently with the other
# service clients in this notebook, rather than from the boto3 module default.
s3_client = boto_session.client(service_name='s3', region_name=region_name)

# Destination of the uploaded dataset.
s3_bucket_name = sagemaker.Session().default_bucket()  # change this to your S3 bucket of choices
s3_prefix_name = 'sample-datasets/raw/consulting-company'  # change this to your S3 prefix of choices

In [None]:
# Upload every CSV file to s3://<bucket>/<prefix>/<name before first dot>/<file>,
# i.e. each file lands in its own sub-folder under the prefix.
for file in file_list:
    local_path = '{0}/{1}'.format(data_dir, file)
    # The part of the file name before the first dot is used as the folder name.
    file_nm = file.split('.')[0]
    s3_key = '{0}/{1}/{2}'.format(s3_prefix_name, file_nm, file)

    logger.info('Uploading {0}/{1} ...'.format(data_dir, file))
    try:
        with open(local_path, 'rb') as data:
            s3_client.upload_fileobj(data, s3_bucket_name, s3_key)
    except Exception as e:
        # Log with traceback, then re-raise unchanged so the cell still fails loudly.
        logger.exception('Something is wrong! {}'.format(e))
        raise

logger.info('completed!')

### Set up Glue Data Catalog
---

In this section, I will create a Glue Data Catalog based on the uploaded CSV files:

1. Glue database which will host the data from Glue Crawler
2. Set up IAM role for Glue Crawler
3. Create AWS Glue Crawler
4. Start the crawler

In [None]:
glue_db_name = 'demo-nlq-db'  # change this accordingly to your choices
glue_client = boto_session.client(
    service_name='glue',
    region_name=region_name,
)

# Create the Glue database. Re-running the cell is fine: an existing database
# with the same name is reused. Any other failure (permissions, throttling,
# invalid name, ...) propagates instead of being mistaken for "already exists".
try:
    logger.info('Attempting to create DB: {}'.format(glue_db_name))
    create_db_resp = glue_client.create_database(
        DatabaseInput={
            'Name': glue_db_name,
            'Description': 'Sample DB for NLQ use case'
        }
    )
    logger.info('Finish creating DB')

except glue_client.exceptions.AlreadyExistsException:
    logger.info('DB {} already exists, reusing it'.format(glue_db_name))

logger.info('Database to use: {}'.format(glue_db_name))

### Create Glue Crawler

In [None]:
import json
import time

iam_client = boto_session.client(
    service_name='iam', 
    region_name=region_name
)
glue_role_name = 'demo-glue-nlq'

# Trust policy that lets the AWS Glue service assume the role.
assume_role_policy_doc = {
    'Version': '2012-10-17',
    'Statement': [{
        'Effect': 'Allow',
        'Action': 'sts:AssumeRole',
        'Principal': {
            'Service': 'glue.amazonaws.com'
        }
    }]
}
assume_role_policy_doc_json = json.dumps(assume_role_policy_doc)

# Create the role, or reuse it when it already exists. Only the
# "already exists" error is handled; anything else (e.g. access denied)
# propagates so the real cause is not hidden behind a failing get_role call.
try:
    logger.info('Attempt to create IAM role')
    logger.info('Creating {} ...'.format(glue_role_name))
    glue_iam_role = iam_client.create_role(
        RoleName=glue_role_name,
        AssumeRolePolicyDocument=assume_role_policy_doc_json,
    )
    # Pause after creation -- presumably to let the new role propagate
    # before it is used; TODO confirm the required delay.
    time.sleep(10)

except iam_client.exceptions.EntityAlreadyExistsException:
    logger.info('IAM role {} already exists, reusing it'.format(glue_role_name))
    glue_iam_role = iam_client.get_role(RoleName=glue_role_name)

# AWS managed policies attached to the role.
# NOTE(review): these are broad "FullAccess" policies -- fine for a short-lived
# demo, but consider scoping them down for anything longer-lived.
policy_arns = [
    'arn:aws:iam::aws:policy/AWSGlueConsoleFullAccess',
    'arn:aws:iam::aws:policy/AmazonS3FullAccess',
    'arn:aws:iam::aws:policy/CloudWatchLogsFullAccess',
]

for policy_arn in policy_arns:
    iam_client.attach_role_policy(
        RoleName=glue_role_name,
        PolicyArn=policy_arn
    )
    time.sleep(5)

logger.info('completed!')

In [None]:
# The account id is needed for the crawler's Lake Formation configuration.
sts_client = boto_session.client(service_name='sts', region_name=region_name)
aws_account_id = sts_client.get_caller_identity()['Account']
glue_crawler_name = 'demo-nlq-crawler'

# Create a crawler that scans everything under the uploaded S3 prefix and
# writes the discovered tables into the Glue database. An existing crawler
# with the same name is reused; every other error propagates so that a
# misconfiguration is not silently ignored.
try:
    logger.info('Attempting to create crawler name: {}'.format(glue_crawler_name))
    glue_client.create_crawler(
        Name=glue_crawler_name,
        Role=glue_role_name,
        DatabaseName=glue_db_name,
        Targets={
            'CatalogTargets': [],
            'DeltaTargets': [],
            'DynamoDBTargets': [],
            'HudiTargets': [],
            'IcebergTargets': [],
            'JdbcTargets': [],
            'MongoDBTargets': [],
            'S3Targets': [{
                'Exclusions': [],
                'Path': 's3://{0}/{1}/'.format(s3_bucket_name, s3_prefix_name)
            }],
        },
        Classifiers=[],
        Configuration='{"Version": 1.0, "CreatePartitionIndex": true}',
        LakeFormationConfiguration={
            'AccountId': aws_account_id,
            'UseLakeFormationCredentials': False
        },
        RecrawlPolicy={
            'RecrawlBehavior': 'CRAWL_EVERYTHING'
        },
        LineageConfiguration={
            'CrawlerLineageSettings': 'ENABLE',
        },
    )

except glue_client.exceptions.AlreadyExistsException:
    logger.info('Crawler {} already exists, reusing it'.format(glue_crawler_name))

### Start Glue Crawler
---

This should take no more than 4 minutes to finish.

In [None]:
# Upper bound on how long to wait for the crawl, and how often to poll.
crawler_timeout_seconds = 600
crawler_poll_seconds = 10

crawler_resp = glue_client.get_crawler(
    Name=glue_crawler_name
)

if crawler_resp['Crawler']['State'] == 'READY':
    logger.info('Start crawler...')
    resp_ = glue_client.start_crawler(Name=glue_crawler_name)

    # Poll the crawler state instead of sleeping for a fixed time: the cell
    # returns as soon as the crawler is back to READY, and warns when the
    # crawl is still running after the timeout.
    deadline = time.monotonic() + crawler_timeout_seconds
    while True:
        time.sleep(crawler_poll_seconds)
        crawler_state = glue_client.get_crawler(Name=glue_crawler_name)['Crawler']['State']
        if crawler_state == 'READY':
            logger.info('Crawl is complete.')
            break
        if time.monotonic() >= deadline:
            logger.warning(
                'Crawler is still in state {0} after {1} seconds'.format(
                    crawler_state, crawler_timeout_seconds
                )
            )
            break
        logger.info('Crawler state is {} ...'.format(crawler_state))
else:
    logger.info('Crawler is not READY (state: {}), not starting it'.format(
        crawler_resp['Crawler']['State']
    ))
    

We can check the status of the last crawl with the command below.

In [None]:
# Fetch the crawler description again and report how its most recent run ended.
_last_crawl = glue_client.get_crawler(Name=glue_crawler_name)['Crawler']['LastCrawl']
logger.info('Last crawl status: {}'.format(_last_crawl['Status']))

## Anthropic Claude on Amazon Bedrock

<div class="alert alert-block alert-info">
    <b>Prerequisite:</b> Ensure that you have model access on Amazon Bedrock console page.
</div>


### Check model ID 
---
We can list the available foundation models by using `list_foundation_models()`. For this demonstration, I will pick the Sonnet foundation model.

In [None]:
# Control-plane Bedrock client (model catalogue), not the runtime client.
bedrock_client = boto_session.client(
    service_name='bedrock', 
    region_name=region_name
)

# IDs of the Claude 3 models that can be invoked on demand. The bare list
# expression is left as the last statement so the notebook displays it.
[
    model['modelId']
    for model in bedrock_client.list_foundation_models()['modelSummaries']
    if 'claude-3' in model['modelId'].lower()
    and 'ON_DEMAND' in model['inferenceTypesSupported']
]

In [None]:
# Bedrock model ID of Claude 3 Sonnet; passed as `model_id` to the generator functions below.
sonnet_model_id = 'anthropic.claude-3-sonnet-20240229-v1:0'

### Invoke Claude foundation model on Amazon Bedrock using boto3 SDK
---

In this example, I will use the `converse` API to call the **Claude Sonnet foundation model**. For more details on the Converse API, please refer to this [documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html). 

**Converse** API provides a consistent interface that works with all models that support messages. This allows you to write code once and use it with different models. If a model has unique inference parameters, you can also pass those unique parameters to the model.

In [None]:
def converse_sql_generator(
    user_question: str,
    model_id: str,
    max_tokens: int=3000,
    temperature: float=0.,
    full_image_filename: str='../schema_img/schema.png',
    boto_session: boto3.session.Session=boto_session
) -> dict:
    '''
    Ask a multi-modal model on Amazon Bedrock to write an Athena SQL query.
    The ERD image and the prompt (with the user's question embedded) are sent
    in a single user message through the Converse API.

    :param user_question: the natural language question to translate to SQL
    :param model_id: Bedrock model ID to invoke
    :param max_tokens: maximum number of tokens to generate (default: 3000)
    :param temperature: sampling temperature (default: 0.)
    :param full_image_filename: path of the ERD image; the image format sent to
        the API is derived from the file extension (png, jpeg/jpg, gif, webp),
        falling back to png for any other extension
    :param boto_session: boto3 session (default: boto_session)

    :return: the raw response dictionary of the Converse API
    '''
    # Derive the image format from the file extension instead of assuming PNG.
    # Unknown extensions keep the previous behavior and are sent as 'png'.
    image_ext = full_image_filename.rsplit('.', 1)[-1].lower()
    image_format = {'jpg': 'jpeg'}.get(image_ext, image_ext)
    if image_format not in ('png', 'jpeg', 'gif', 'webp'):
        image_format = 'png'

    with open(full_image_filename, 'rb') as image_file:
        image_bytes = image_file.read()
    
    bedrock_runtime_client = boto_session.client(
        service_name='bedrock-runtime',
        region_name=boto_session.region_name,
    )
    
    # A single user turn with two content blocks: the text prompt and the image.
    message = {
        'role': 'user',
        'content': [{
            'text': '''
            You are a SQL statement generation expert, and are assigned to generate SQL statements executed on Amazon Athena.
            Amazon Athena built on open-source Trino and Presto engines, so your SQL should be executed successfully on Presto.

            You will be given the image of ERD (Entity Relationship Diagram), which represent the relationship of tables, primary, and join key. 

            Read it and ensure you understand the database structures.
            It is IMPORTANT to respect the type of columns: if a column is string, the value should be enclosed in quotes.
            While concatenating a non string column, make sure to cast the column to string.
            For date columns comparing to string, please cast the string input.

            First, you will need to list what tables are in the diagram and what are the join keys for each table!
            Secondly, You will be presented with the question within <question> tag.
            Lastly, generate the SQL to get the answer for the question using relationship from the given image.

            <question>
            {0}
            </question>

            If you cannot generate the SQL from the attached diagram, respond with "Sorry, there're not information to generate the SQL query".
            Once you generate the SQL query, reexamine your SQL again! 

            1. Make sure the columns exist in each table as in the given image!
            2. Make sure the join keys in each table are correct according to the given image!

            Your final answer will be in XML format.
            <result>
            <sql>SQL query</sql>
            <explanation>Explain clearly your approach, what the query does, and its syntax</explanation>
            </result>
            
            '''.format(user_question),
        }, {
            'image': {
                'format': image_format,
                'source': {
                    'bytes': image_bytes
                }
            },
        }],
    }
    
    inference_config = {
        'maxTokens': max_tokens,
        'temperature': temperature
    }
    
    converse_response = bedrock_runtime_client.converse(
        modelId=model_id,
        messages=[message],
        inferenceConfig=inference_config,
    )
    return converse_response


def get_text_output(converse_resp: dict) -> str:
    '''
    Extract the generated text from a Converse API response.
    The first content block that carries a 'text' entry is returned, so a
    response whose first block is not text is still handled.

    :param converse_resp: response dictionary returned by the Converse API

    :return: the text of the first text content block
    :raises LookupError: if the response message has no text content block
    '''
    for content_block in converse_resp['output']['message']['content']:
        if 'text' in content_block:
            return content_block['text']
    raise LookupError('No text content block found in the Converse API response')

In [None]:
# Smoke test: ask one simple question and keep the raw Converse response.
test_question = 'How many projects are we having?'
resp = converse_sql_generator(user_question=test_question, model_id=sonnet_model_id)

In [None]:
# Pull the generated text out of the Converse response and show it.
test_answer = get_text_output(resp)
print(test_answer)

Moreover, you can access tokens usage and latency metrics from the `converse` API output as well.

In [None]:
# Print the 'metrics' entry first, then the 'usage' entry of the response.
for _resp_key in ('metrics', 'usage'):
    print(resp[_resp_key])

Let's quickly test our functions.

## Query the data lake using Amazon Athena
---

Because we have crawled the data using **Glue data crawler**, we can use **Amazon Athena** service to query the data.
There are several steps to call and get the query result from Amazon Athena.

1. Start query execution
2. Get the query state
3. Once it is *SUCCEEDED*, get the query result

Let's create the function to encapsulate these steps!

In [None]:
def get_sql_script(
    string_resp_from_llm: str
) -> str:
    '''
    Function to get SQL query from <sql> tag
    :param string_resp_from_llm: the full text answer returned by the LLM

    :return: the text between the first <sql> tag and the following </sql> tag,
        with surrounding whitespace removed; if the closing tag is missing,
        everything after <sql> is returned
    :raises ValueError: if the answer contains no <sql> tag (e.g. the model
        replied that it cannot generate a query)
    '''
    _, opening_tag, after_opening = string_resp_from_llm.partition('<sql>')
    if not opening_tag:
        raise ValueError('No <sql> tag found in the LLM response: {!r}'.format(string_resp_from_llm))

    sql_output = after_opening.partition('</sql>')[0]
    return sql_output.strip()


def call_athena(
    sql_script: str,
    glue_db_name: str=glue_db_name,
    boto_session: boto3.session.Session=boto_session
) -> Optional[List[dict]]:
    '''
    Function to call Amazon Athena and use the provided SQL query, and wait to get the query result.
    Query output is written to s3://<s3_bucket_name>/athena-query-results/ (module-level bucket name).
    :param sql_script: the SQL script to run against Amazon Athena
    :param glue_db_name: Glue database used as the query execution context (default: glue_db_name)
    :param boto_session: boto3 session (default: boto_session)
    
    :return: the list of SQL result rows returning from Amazon Athena, or None when the
        query could not be started or did not end in the SUCCEEDED state
    '''
    # States after which Athena no longer changes the status of a query.
    # CANCELLED must be included, otherwise a cancelled query is polled forever.
    terminal_states = ('SUCCEEDED', 'FAILED', 'CANCELLED')

    # One client is enough for all three steps.
    athena_client = boto_session.client(service_name='athena', region_name=boto_session.region_name)

    def _start_query_execution_() -> Optional[str]:
        ''' Function to call Amazon Athena and use the provided SQL query

        :return: the execution ID used for further tracking and result retrieval,
            or None when Athena rejected the request
        '''
        execution_id = None

        try:
            logger.info('Start query execution on Amazon Athena...')
            query_response = athena_client.start_query_execution(
                QueryString=sql_script,
                ResultConfiguration={
                    'OutputLocation': 's3://{}/athena-query-results/'.format(s3_bucket_name),
                    'EncryptionConfiguration': {
                        'EncryptionOption': 'SSE_S3',
                    },
                },
                QueryExecutionContext={
                    'Database': glue_db_name,
                },
            )
            execution_id = query_response["QueryExecutionId"]
            logger.info('SQL script is executing ...')
            logger.info('Query execution ID: {} ...'.format(execution_id))

        except ClientError as e:
            logger.error('The provided SQL query has syntax error!!')
            logger.error(e)

        return execution_id

    def _get_query_state_(execution_id: str) -> Tuple[str, dict]:
        ''' 
        Function to get the query state from Amazon Athena
        Remark: possible query state is 'QUEUED', 'RUNNING', 'SUCCEEDED', 'FAILED', or 'CANCELLED'
        :param execution_id: the execution ID used for further tracking and result retrieval

        :return: Tuple of the query state and the 'QueryExecution' dictionary of the response
        '''
        try:
            get_query_state_resp = athena_client.get_query_execution(QueryExecutionId=execution_id)
            _query_state = get_query_state_resp['QueryExecution']['Status']['State']
            logger.info('The current query state is {} ...'.format(_query_state))
            return _query_state, get_query_state_resp['QueryExecution']

        except ClientError as e:
            logger.error('Something went wrong when trying to get query state')
            logger.error(e)
            raise

    def _get_query_result_(execution_id: str) -> List[dict]:
        ''' 
        Function to fetch the result rows of a finished query
        NOTE: a single get_query_results call is made, so only the first page of rows is returned
        :param execution_id: the execution ID of the finished query

        :return: the 'Rows' list of the result set
        '''
        try:
            logger.info('Retrieving the SQL result ...')
            query_result_resp = athena_client.get_query_results(QueryExecutionId=execution_id)
            logger.info('Finish fetching result =) .... ')
            return query_result_resp['ResultSet']['Rows']

        except ClientError as e:
            logger.error('Cannot get SQL result...')
            logger.error(e)
            raise

    execution_id = _start_query_execution_()
    if execution_id is None:
        # The query was rejected (already logged); there is nothing to poll or fetch.
        return None

    _query_state, _query_resp_dict = _get_query_state_(execution_id=execution_id)
    while _query_state not in terminal_states:
        time.sleep(5)
        _query_state, _query_resp_dict = _get_query_state_(execution_id=execution_id)

    if _query_state != 'SUCCEEDED':
        # No result set exists for a failed or cancelled query; report why it ended.
        logger.error('Query {0} ended in state {1}: {2}'.format(
            execution_id,
            _query_state,
            _query_resp_dict['Status'].get('StateChangeReason', 'no reason reported'),
        ))
        return None

    return _get_query_result_(execution_id=execution_id)

In [None]:
# Extract the SQL from the model's answer and run it on Athena.
test_query_result = call_athena(get_sql_script(test_answer))

In [None]:
# Bare expression so the notebook displays the value returned by call_athena.
test_query_result

As we can see, the result is returned in a rather technical format. We will feed this output to another foundation model to summarize it and return a human-friendly answer.

## Summarize the output to human-friendly text
---

Ultimately, you would like your generative AI application to respond in a more natural manner instead of a technical format such as a list or dictionary.

In [None]:
def converse_answer_generator(
    user_question: str,
    sql_response: List[dict],
    model_id: str,
    max_tokens: int=3000,
    temperature: float=0.,
    boto_session: boto3.session.Session=boto_session
) -> dict:
    '''
    Ask a model on Amazon Bedrock to turn the SQL result into a short answer.
    The question and the result (via its string representation) are embedded
    in a text-only prompt and sent through the Converse API.

    :param user_question: the original natural language question
    :param sql_response: the rows returned by the SQL query, used as context
    :param model_id: Bedrock model ID to invoke
    :param max_tokens: maximum number of tokens to generate (default: 3000)
    :param temperature: sampling temperature (default: 0.)
    :param boto_session: boto3 session (default: boto_session)

    :return: the raw response dictionary of the Converse API
    '''
    logger.info('Generating final response from the provided question and SQL result ...')
    bedrock_runtime_client = boto_session.client(
        service_name='bedrock-runtime',
        region_name=boto_session.region_name,
    )
    # The <context> block is closed with a well-formed </context> tag so the
    # model sees a properly delimited context section.
    ans_generator_template = '''
    You are to answer the question in <question> tag to the best of your ability based on the given context in <context> tag.
    
    The context are kept in Dictionary format within List data type.
    The first element of list is the column, and the second element of is the actual data.
    If there is only one element in the list, this means that there's no result returned.
    As such, the context means there's no or 0 for their question. 
    
    Your response should be PRECISE!!! TRY NOT TO REPEAT THE QUESTION WHEN RESPONSE! 
    AND GIVE ONLY THE ANSWER, No need to reiterate the context!
    
    Do not include information that is not relevant to the question.
    Only provide information based on the context provided, and do not make assumptions!

    <question>
    {question}
    </question>

    <context>
    {context}
    </context>
    '''.format(question=user_question, context=sql_response)
    
    message = {
        'role': 'user',
        'content': [{
            'text': ans_generator_template,
        },],
    }
    
    inference_config = {
        'maxTokens': max_tokens,
        'temperature': temperature
    }
    
    converse_response = bedrock_runtime_client.converse(
        modelId=model_id,
        messages=[message],
        inferenceConfig=inference_config,
    )
    return converse_response

In [None]:
# Model used for the summarization step. It has to be defined before the call
# below; make sure model access to Claude 3 Haiku is enabled on Amazon Bedrock.
haiku_model_id = 'anthropic.claude-3-haiku-20240307-v1:0'

resp2 = converse_answer_generator(
    user_question=test_question,
    sql_response=test_query_result,
    model_id=haiku_model_id,
)
print(get_text_output(resp2))

## Put it all together
---
Let's test it with several questions.

In [None]:
# Sample questions, from a simple count to multi-table aggregations.
question_list = [
    'How many employees are there?',
    'Who are working on each of the projects?',
    'What is the least project we are spending time on? And who are in the projects and how much time spent?',
    'List top 3 employees who have spending the time on the project the most.',
    'Are there any employees working on multiple projects? If yes, what is the project and the time spent on each project?',
]

# Individual names are kept so that a single question can be referenced directly.
question_01, question_02, question_03, question_04, question_05 = question_list

In [None]:
from IPython.display import Markdown

# End-to-end pipeline per question: question -> SQL -> Athena result -> answer.
for question in question_list:
    # 1. Let the model write the SQL from the ERD image and the question.
    _sql_resp = converse_sql_generator(user_question=question, model_id=sonnet_model_id)
    _generate_sql = get_sql_script(get_text_output(_sql_resp))

    # 2. Run the SQL on Athena.
    _athena_query_result = call_athena(_generate_sql)

    # 3. Let the model phrase the result as a short answer.
    _answer_resp = converse_answer_generator(
        user_question=question,
        model_id=sonnet_model_id,
        sql_response=_athena_query_result,
    )
    _answer = get_text_output(_answer_resp)

    # 4. Render question, answer, and SQL as colored Markdown, in that order.
    for _label, _color, _value in (
        ('Question', '#ff0000', question),
        ('Answer', '#0000FF', _answer),
        ('SQL Generated', '#A020F0', _generate_sql),
    ):
        display(Markdown('**{0}:** <span style="color: {1}">{2}</span>'.format(_label, _color, _value)))

    time.sleep(2)

## Clean up
---

Once you are done with the experiment, please ensure you have deleted all resources created in this demonstration to avoid incurring any costs.

### Clean up data uploaded to Amazon S3

In [None]:
# Rebuild the S3 keys exactly as they were built at upload time:
# <prefix>/<name before first dot>/<file>.
keys_to_delete = [
    '{0}/{1}/{2}'.format(s3_prefix_name, file.split('.')[0], file)
    for file in file_list
]

# delete_objects accepts a batch of keys, so send them in chunks of at most
# 1000 instead of issuing one request per file.
max_keys_per_request = 1000
for start in range(0, len(keys_to_delete), max_keys_per_request):
    key_batch = keys_to_delete[start:start + max_keys_per_request]
    for s3_key in key_batch:
        logger.info('Deleting {} ...'.format(s3_key))
    try:
        delete_resp = s3_client.delete_objects(
            Bucket=s3_bucket_name,
            Delete={
                'Objects': [{'Key': s3_key} for s3_key in key_batch],
                'Quiet': True
            },
        )
    except Exception as e:
        logger.exception('Something is wrong! {}'.format(e))
        raise

    # Keys that could not be deleted are reported in the response rather than
    # raised as an exception, so surface them explicitly.
    for failed in delete_resp.get('Errors', []):
        logger.error('Could not delete {0}: {1}'.format(failed.get('Key'), failed.get('Message')))



### Detach IAM policies and delete IAM role

In [None]:
# Detach every policy that was attached to the role, pausing after each call.
for policy_arn in policy_arns:
    iam_client.detach_role_policy(PolicyArn=policy_arn, RoleName=glue_role_name)
    time.sleep(5)

# Delete the role itself; the response is assigned to `_` so the cell shows no output.
_ = iam_client.delete_role(RoleName=glue_role_name)

### Delete Glue database and Glue Crawler

In [None]:
# Remove the Glue database first, then the crawler; responses are assigned
# to `_` so the cell shows no output.
_ = glue_client.delete_database(Name=glue_db_name)
_ = glue_client.delete_crawler(Name=glue_crawler_name)