# Implementing AI Safety: Exploring Contextual Grounding with Amazon Bedrock and Meta Llama Models


## Introduction
---
[**Amazon Bedrock**](https://aws.amazon.com/bedrock/) is a fully managed service that offers a choice of high-performing foundation models (FMs) from leading AI companies through a single API. It provides a comprehensive set of capabilities to build generative AI applications with security, privacy, and responsible AI.

Guardrails in Amazon Bedrock let developers implement safeguards and controls on the inputs and outputs of language models. These guardrails help ensure that AI-generated content aligns with business policies, maintains brand consistency, and adheres to ethical standards. For more information on Amazon Bedrock Guardrails, refer to the [official AWS documentation on guardrails](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails.html).

In this notebook, we explore how to implement **AI safety guardrails**, specifically the <u>contextual grounding check</u> for a retrieval-augmented generation (RAG) application, using a Meta Llama large language model.


## Setup

In [None]:
%pip install -qU -r requirements.txt

## Bedrock Guardrails for Contextual Grounding

### Create Bedrock Guardrails
---

To create a guardrail, use the `create_guardrail` API of the **boto3** `bedrock` client.

In [None]:
import boto3
from botocore.exceptions import ClientError
import json
import time
import uuid
from typing import Tuple, Optional, List, Dict

boto_session = boto3.session.Session()
region_name = boto_session.region_name
bedrock_client = boto_session.client(
    service_name='bedrock',
    region_name=region_name,
)
bedrock_runtime_client = boto_session.client(
    service_name='bedrock-runtime',
    region_name=region_name,
)
guardrail_name = 'contextual_grounding_check'

In [None]:
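def check_existing_guardrail(
    guardrail_name: str,
    boto_session: boto3.session.Session
) -> Tuple[bool, str]:
    """Return (True, guardrail_id) if a guardrail with this name exists, else (False, '')."""
    bedrock_client = boto_session.client(
        service_name='bedrock',
        region_name=boto_session.region_name
    )
    list_kwargs: Dict[str, str] = {}
    while True:
        # list_guardrails is paginated; follow nextToken until the last page.
        resp = bedrock_client.list_guardrails(**list_kwargs)
        for guardrail in resp.get('guardrails', []):
            if guardrail.get('name', '') == guardrail_name:
                print('Guardrail "{}" exists'.format(guardrail_name))
                return True, guardrail.get('id')
        if not resp.get('nextToken'):
            return False, ''
        list_kwargs = {'nextToken': resp['nextToken']}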


def check_to_delete_guardrail(
    guardrail_name: str,
    boto_session: boto3.session.Session
) -> None:
    """Delete the guardrail with this name, if one exists."""
    bedrock_client = boto_session.client(
        service_name='bedrock',
        region_name=boto_session.region_name
    )
    exist_ind, _id = check_existing_guardrail(guardrail_name, boto_session)
    if exist_ind:
        print('Deleting existing guardrail')
        _ = bedrock_client.delete_guardrail(guardrailIdentifier=_id)
        time.sleep(5)  # give the deletion a few seconds to complete
        print('Delete completed...')
    else:
        print('No guardrail named "{}" found'.format(guardrail_name))
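The fixed `time.sleep(...)` pauses in these helpers assume each operation finishes within a few seconds. A more robust pattern is to poll until a condition holds or a timeout expires. The sketch below is a generic helper (`wait_until` is a hypothetical name, not a boto3 waiter); what it polls is up to the caller.

```python
import time
from typing import Callable


def wait_until(
    condition: Callable[[], bool],
    timeout_s: float = 60.0,
    interval_s: float = 2.0,
) -> bool:
    """Poll `condition` until it returns True or `timeout_s` elapses.

    Returns True if the condition was met, False on timeout.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        if condition():
            return True
        if time.monotonic() >= deadline:
            return False
        # Wait before checking again to avoid hammering the API.
        time.sleep(interval_s)
```

For example, assuming `get_guardrail` reports a `status` field that becomes `'READY'` once creation finishes, `wait_until(lambda: bedrock_client.get_guardrail(guardrailIdentifier=guardrail_id)['status'] == 'READY')` could replace a fixed sleep.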

In [None]:
def create_contextual_grounding_guardrail(
    guardrail_name: str,
    guardrail_desc: str,
    relevance_scor: float,
    grounding_scor: float,
    boto_session: boto3.session.Session,
    tag_list: Optional[List[Dict[str, str]]] = None
) -> dict:
    """Create a guardrail whose only policy is the contextual grounding check."""
    guardrail_config = {
        'name': guardrail_name,
        'description': guardrail_desc,
        'blockedInputMessaging': 'I am sorry, but I cannot process that input request.',
        'blockedOutputsMessaging': 'I apologize, but I cannot provide this information',
        'contextualGroundingPolicyConfig': {
            'filtersConfig': [{
                'type': 'RELEVANCE',
                'threshold': relevance_scor
            }, {
                'type': 'GROUNDING',
                'threshold': grounding_scor
            }]
        },
        # Use a fresh list per call rather than a shared mutable default.
        'tags': tag_list or []
    }
    try:
        bedrock_client = boto_session.client(
            service_name='bedrock',
            region_name=boto_session.region_name,
        )
        create_resp = bedrock_client.create_guardrail(**guardrail_config)
        time.sleep(5)  # give the new guardrail a few seconds to become available
        print('Create guardrail "{}" completed'.format(guardrail_name))
        return create_resp
    except ClientError as e:
        print('Error creating guardrail: {e}'.format(e=e))
        raise
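`create_contextual_grounding_guardrail` passes both thresholds straight to the API. A small local check can catch typos (for example `85` instead of `.85`) before the call is made. This sketch assumes the valid range is 0 to 0.99, the range offered in the Bedrock console; confirm the bound in the API reference. `build_contextual_grounding_policy` is a hypothetical helper that returns the same `contextualGroundingPolicyConfig` shape used above.

```python
from typing import Dict, List


def build_contextual_grounding_policy(
    relevance_threshold: float,
    grounding_threshold: float,
    max_threshold: float = 0.99,  # assumed upper bound; confirm in the API reference
) -> Dict[str, List[Dict[str, object]]]:
    """Build a `contextualGroundingPolicyConfig` value, validating thresholds first."""
    filters: List[Dict[str, object]] = []
    for filter_type, threshold in (('RELEVANCE', relevance_threshold),
                                   ('GROUNDING', grounding_threshold)):
        # Reject out-of-range values locally instead of waiting for the API error.
        if not 0.0 <= threshold <= max_threshold:
            raise ValueError(
                f'{filter_type} threshold {threshold} is outside [0, {max_threshold}]'
            )
        filters.append({'type': filter_type, 'threshold': threshold})
    return {'filtersConfig': filters}
```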


In [None]:
check_to_delete_guardrail(guardrail_name, boto_session)
create_guardrail_resp = create_contextual_grounding_guardrail(
    guardrail_name=guardrail_name,
    guardrail_desc='Guardrail to detect hallucinations from RAG application',
    relevance_scor=.85,
    grounding_scor=.85,
    boto_session=boto_session,
    tag_list=[{
        'key': 'guardrail-policy',
        'value': 'contextual-grounding-check',
    }],
)
guardrail_id = create_guardrail_resp.get('guardrailId')

### Create Guardrail version
---
Once the guardrail is defined, use the `create_guardrail_version` API to create a snapshot of its current configuration. Create a version when you are satisfied with a configuration, when you want to freeze it for testing, or when you want to A/B test one configuration against another version.
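For the A/B testing case, one simple approach is to assign each user session to a guardrail version deterministically, so a given session always sees the same configuration. The helper below is a hypothetical sketch, not a Bedrock feature: it hashes a session ID with SHA-256 and sends a configurable share of sessions to version B.

```python
import hashlib


def pick_guardrail_version(
    session_id: str,
    version_a: str,
    version_b: str,
    share_b: float = 0.5,
) -> str:
    """Deterministically route a session to guardrail version A or B.

    `share_b` is the fraction of sessions (0.0 to 1.0) sent to version B.
    """
    # Map the session ID to a stable bucket in [0, 10000).
    digest = hashlib.sha256(session_id.encode('utf-8')).hexdigest()
    bucket = int(digest, 16) % 10_000
    return version_b if bucket < share_b * 10_000 else version_a
```

Because the bucket depends only on the session ID, repeated requests from the same session keep hitting the same version, which keeps the comparison between versions clean.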


In [None]:
def create_version_guardrail(
    guardrail_id: str,
    version_desc: str,
    boto_session: boto3.session.Session,
    request_token: Optional[str] = None
) -> dict:
    """Snapshot the current guardrail configuration as a new numbered version."""
    bedrock_client = boto_session.client(
        service_name='bedrock',
        region_name=boto_session.region_name,
    )
    # Generate the idempotency token at call time so each call gets a unique one.
    request_token = request_token or str(uuid.uuid4())
    try:
        create_version_resp = bedrock_client.create_guardrail_version(
            guardrailIdentifier=guardrail_id,
            description=version_desc,
            clientRequestToken=request_token
        )
        time.sleep(3)
        return create_version_resp
    except ClientError as e:
        print('Error creating guardrail version: {e}'.format(e=e))
        raise

In [None]:
create_version_resp = create_version_guardrail(
    guardrail_id=guardrail_id,
    version_desc='Version 1 - contextual grounding check only!',
    boto_session=boto_session,
)
guardrail_version_id = create_version_resp.get('version', '')

Take note of **guardrail_id** and **guardrail_version_id**: both are required whenever the guardrail is used.
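One way to keep the two values together is a small immutable record that also assembles the keyword arguments for `apply_guardrail`, which this notebook calls later. `GuardrailRef` and its method are hypothetical names; the dictionary keys (`guardrailIdentifier`, `guardrailVersion`, `source`, `content`) are the ones used in the notebook's `apply_guardrail` call.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass(frozen=True)
class GuardrailRef:
    """The two identifiers needed whenever the guardrail is used."""
    guardrail_id: str
    guardrail_version: str

    def apply_guardrail_kwargs(
        self,
        content: List[Dict[str, Any]],
        source: str = 'OUTPUT',
    ) -> Dict[str, Any]:
        """Keyword arguments for bedrock-runtime `apply_guardrail`."""
        return {
            'guardrailIdentifier': self.guardrail_id,
            'guardrailVersion': self.guardrail_version,
            'source': source,
            'content': content,
        }
```

Usage, once `content` has been assembled: `bedrock_runtime_client.apply_guardrail(**GuardrailRef(guardrail_id, guardrail_version_id).apply_guardrail_kwargs(content))`.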

## Apply to RAG application
---

To illustrate the contextual grounding check, we use the `apply_guardrail` API, which lets us run retrieval, generation, and the guardrail evaluation as separate steps.

<img src='./img/Bedrock-guardrails_how-it-works.png' alt="how it works apply_guardrails" style='width: 800px;'/>
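Because the three steps are separate, they can be wired together as plain functions. The sketch below (the function name and the callable signatures are assumptions) runs retrieval, generation, and the check in order, and returns the guardrail's replacement text when the response `action` is `'GUARDRAIL_INTERVENED'`.

```python
from typing import Any, Callable, Dict, List, Tuple


def guarded_rag_answer(
    question: str,
    retrieve: Callable[[str], List[str]],
    generate: Callable[[str, List[str]], str],
    check: Callable[[str, List[str], str], Dict[str, Any]],
) -> Tuple[str, Dict[str, Any]]:
    """Run retrieval, generation, and the guardrail check as three separate steps.

    Each step is injected as a plain function, so any retriever, any model, and
    any checker (for example a wrapper around apply_guardrail) can be plugged in.
    """
    sources = retrieve(question)                # 1. fetch grounding passages
    answer = generate(question, sources)        # 2. call the LLM
    verdict = check(question, sources, answer)  # 3. evaluate the answer
    # When the guardrail intervened, show its replacement text instead.
    if verdict.get('action') == 'GUARDRAIL_INTERVENED':
        outputs = verdict.get('outputs', [])
        answer = outputs[0].get('text', '') if outputs else ''
    return answer, verdict
```

Since the model call is just another function here, the same check can sit behind any model, which is the decoupling `apply_guardrail` enables.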

### Connect to existing vector database
---

Connect to the existing Chroma vector database and create a retriever that returns the top 3 chunks for each query.

In [None]:
from langchain_chroma import Chroma
import langchain_aws
import langchain_core
from langchain_aws import BedrockEmbeddings
import chromadb

chroma_db_dir = './../_vector_db'
chroma_collection_name = 'amazon-shareholder-letters'
titan_model_id = 'amazon.titan-embed-text-v2:0'
titan_embedding_fn = BedrockEmbeddings(
    model_id=titan_model_id,
    region_name=boto_session.region_name
)
persistent_client = chromadb.PersistentClient(
    path=chroma_db_dir,
)
vector_store = Chroma(
    collection_name=chroma_collection_name,
    embedding_function=titan_embedding_fn,
    client=persistent_client,
)
chroma_retriver = vector_store.as_retriever(
    search_kwargs={'k': 3}
)

Next, ask a sample question and query the vector database.

In [None]:
sample_question = '''
Amazon discusses its investments and progress in various areas, such as Generative AI, logistics, and healthcare. 
How do these initiatives relate to the company's strategy of building "primitives" or foundational building blocks, 
and what potential customer experiences or business opportunities do they enable?'''

sources_ref_list = chroma_retriver.invoke(sample_question)
print(sources_ref_list)

### Define apply guardrail
---

The `apply_guardrail` API can be called independently of any model invocation. Given the three components below, it evaluates the model response against the **grounding** and **relevance** metrics:
1. **Query** (qualifier `query`): the input prompt or user question.
2. **Context, or grounding source** (qualifier `grounding_source`): the context or chunks retrieved from the vector store.
3. **Model response, or guard content** (qualifier `guard_content`): the output of the LLM; the `apply_guardrail` API determines whether this content should be blocked.
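Each component is passed to the API as a text block tagged with its qualifier, as in `contextual_grounding_check` later in this notebook. A standalone sketch of that mapping (`build_guardrail_content` is a hypothetical name):

```python
from typing import Dict, List


def build_guardrail_content(
    query: str,
    grounding_sources: List[str],
    model_response: str,
) -> List[Dict[str, Dict[str, object]]]:
    """Tag each piece of text with the qualifier apply_guardrail expects."""
    def block(text: str, qualifier: str) -> Dict[str, Dict[str, object]]:
        return {'text': {'text': text, 'qualifiers': [qualifier]}}

    # One block per retrieved chunk, then the question, then the answer to check.
    content = [block(src, 'grounding_source') for src in grounding_sources]
    content.append(block(query, 'query'))
    content.append(block(model_response, 'guard_content'))
    return content
```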

<div class="alert alert-block alert-info">
    <b>Note</b>: The contextual grounding check is applied to the <b>OUTPUT</b> (the LLM response), not to the <b>INPUT</b> sent to the LLM. Hence, we specify <i>source='OUTPUT'</i> in the API call.
</div>

For more information on the API call, please refer to [boto3 documentation](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/apply_guardrail.html#BedrockRuntime.Client.apply_guardrail).
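The response fields read by `hallucination_detection` later in this notebook are `action`, `outputs`, and `assessments`, where each contextual grounding filter reports its `type`, `score`, and `threshold`. The sketch below collects those per filter type and runs on a hand-written response; the values in `mock_response` are made up for illustration and are not real API output.

```python
from typing import Any, Dict, Optional


def grounding_scores(response: Dict[str, Any]) -> Dict[str, Dict[str, Optional[float]]]:
    """Collect score and threshold for each contextual grounding filter type."""
    scores: Dict[str, Dict[str, Optional[float]]] = {}
    for assessment in response.get('assessments', []):
        policy = assessment.get('contextualGroundingPolicy') or {}
        for f in policy.get('filters', []):
            scores[f['type']] = {
                'score': f.get('score'),
                'threshold': f.get('threshold'),
            }
    return scores


# Hand-written response for illustration only; the values are made up.
mock_response = {
    'action': 'GUARDRAIL_INTERVENED',
    'outputs': [{'text': 'I apologize, but I cannot provide this information'}],
    'assessments': [{
        'contextualGroundingPolicy': {
            'filters': [
                {'type': 'GROUNDING', 'threshold': 0.85, 'score': 0.31},
                {'type': 'RELEVANCE', 'threshold': 0.85, 'score': 0.92},
            ]
        }
    }],
}
print(grounding_scores(mock_response))
```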

In [None]:
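def call_bedrock_llm(
    input_prompt: str,
    context_str: str,
    bedrock_model_id: str,
    boto_session: boto3.session.Session,
) -> dict:
    """Answer `input_prompt` with the Converse API, passing the retrieved context in the system prompt."""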
    bedrock_runtime_client = boto_session.client(
        service_name='bedrock-runtime',
        region_name=boto_session.region_name,
    )
    messages = [{
        'role': 'user',
        'content': [{
            'text': input_prompt
        }],
    }]

    llm_converse_resp = bedrock_runtime_client.converse(
        modelId=bedrock_model_id,
        system=[{
            'text': '''You are a helpful assistant. Here is the context
            <context>
            {}
            </context>
            '''.format(context_str)
        }],
        messages=messages,
        inferenceConfig={
            'maxTokens': 2048,
            'temperature': 0.1,
            'topP': .85
        }
    )
    return llm_converse_resp


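def contextual_grounding_check(
    guardrail_id: str,
    guardrail_version: str,
    bedrock_model_id: str,
    input_prompt: str,
    context_lists: List[langchain_core.documents.Document],
    boto_session: boto3.session.Session,
    verbose: bool = False
) -> Tuple[dict, dict]:
    """Generate an answer with the LLM, then run the contextual grounding check on it."""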
    context_str = '\n---\n'.join([doc.page_content for doc in context_lists])
    _llm_resp = call_bedrock_llm(
        input_prompt=input_prompt,
        context_str=context_str,
        bedrock_model_id=bedrock_model_id,
        boto_session=boto_session
    )
    _llm_output = _llm_resp.get('output', {}).get('message', {})\
        .get('content', [])[0].get('text', 'NA').strip()
    _bedrock_tokens_usg = _llm_resp.get('usage', {})
    if verbose:
        print('Here is token usage by {}'.format(bedrock_model_id))
        print(json.dumps(_bedrock_tokens_usg, indent=2))

    _guardrail_content_block = []
    for doc in context_lists:
        _ground = {
            'text': {
                'text': doc.page_content,
                'qualifiers': [
                    'grounding_source'
                ]
            }
        }
        _guardrail_content_block.append(_ground)

    if input_prompt:
        _query = {
            'text': {
                'text': input_prompt,
                'qualifiers': [
                    'query'
                ]
            }
        }
        _guardrail_content_block.append(_query)

    if _llm_output:
        _guard_content = {
            'text': {
                'text': _llm_output,
                'qualifiers': [
                    'guard_content'
                ]
            }
        }
        _guardrail_content_block.append(_guard_content)

    bedrock_runtime_client = boto_session.client(
        service_name='bedrock-runtime',
        region_name=boto_session.region_name,
    )
    guardrail_resp = bedrock_runtime_client.apply_guardrail(
        guardrailIdentifier=guardrail_id,
        guardrailVersion=guardrail_version,
        source='OUTPUT',
        content=_guardrail_content_block
    )
    return guardrail_resp, _llm_resp


def hallucination_detection(
    guardrail_response: dict,
    verbose: bool = False
) -> Tuple[bool, float, float, float, float, str, str]:
    """Return (flagged, relevance, grounding, relevance_threshold,
    grounding_threshold, guardrail_action, guardrail_output)."""
    _guardrail_usg = guardrail_response.get('usage', {})
    _guardrail_action = guardrail_response.get('action', '')
    _outputs = guardrail_response.get('outputs', [])
    _guardrail_out = _outputs[0].get('text', '') if _outputs else ''

    # Defaults so the return values are defined even when the response
    # contains no contextual grounding assessment.
    relevance = grounding = 0.0
    relevance_threshold = grounding_threshold = 0.0
    flagged = False

    if verbose:
        print('guardrail usage:\n{}'.format(json.dumps(_guardrail_usg, indent=2)))
    for assessment in guardrail_response.get('assessments', []):
        policy = assessment.get('contextualGroundingPolicy')
        if not policy:
            continue
        for filter_result in policy.get('filters', []):
            if filter_result['type'] == 'RELEVANCE':
                relevance = filter_result.get('score', 0)
                relevance_threshold = filter_result.get('threshold', 0)
            elif filter_result['type'] == 'GROUNDING':
                grounding = filter_result.get('score', 0)
                grounding_threshold = filter_result.get('threshold', 0)
        # A score below its threshold means the response is flagged.
        if relevance < relevance_threshold or grounding < grounding_threshold:
            flagged = True
            break

    return (flagged, relevance, grounding, relevance_threshold,
            grounding_threshold, _guardrail_action, _guardrail_out)

In [None]:
llama3_1_70b_model_id = 'meta.llama3-1-70b-instruct-v1:0'

guardrail_resp, llm_resp = contextual_grounding_check(
    guardrail_id=guardrail_id,
    guardrail_version=guardrail_version_id,
    bedrock_model_id=llama3_1_70b_model_id,
    input_prompt=sample_question,
    context_lists=sources_ref_list,
    boto_session=boto_session
)

### Example execution

In [None]:
import pandas as pd

eval_df = pd.read_csv('../_eval_data/eval_dataframe.csv')

In [None]:
from IPython.display import display, Markdown


for _idx, row in eval_df.iterrows():
    # Strip an optional "<label>:" prefix, keeping everything after the first colon.
    _question = row['input'].split(':', 1)[1].strip() if ':' in row['input'] else row['input']
    sources_ref_list = chroma_retriver.invoke(_question)
    guardrail_resp, llm_resp = contextual_grounding_check(
        guardrail_id=guardrail_id,
        guardrail_version=guardrail_version_id,
        bedrock_model_id=llama3_1_70b_model_id,
        input_prompt=_question,
        context_lists=sources_ref_list,
        boto_session=boto_session
    )
    _ind, rel_scor, ground_scor, rel_thre, ground_thre, _guardrail_action, \
        _guardrail_output = hallucination_detection(guardrail_resp)
    display(Markdown("<font color='red'>Question: {}</font>".format(_question)))
    display(Markdown("<font color='blue'>LLM Response: {}</font>".format(
        llm_resp.get('output').get('message').get('content')[0].get('text')
    )))
    display(Markdown("<font color='green'>Guardrail Action: {}</font>".format(_guardrail_action)))
    display(Markdown("<font color='green'>Guardrail output: {}</font>".format(_guardrail_output)))
    display(Markdown(
        "<font color='green'>Relevance score ({rel_scor}) vs. threshold ({rel_thre})</font>"
        .format(rel_scor=rel_scor, rel_thre=rel_thre)))
    display(Markdown(
        "<font color='green'>Grounding score ({ground_scor}) vs. threshold ({ground_thre})</font>"
        .format(ground_scor=ground_scor, ground_thre=ground_thre)))
    display(Markdown(' --- '))
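The loop above only displays each result. To compare runs (for example, different thresholds or guardrail versions), it can help to collect the results and summarize them. A minimal standard-library sketch; `CheckResult` and `summarize` are hypothetical names, with fields matching values the loop already computes.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class CheckResult:
    """One evaluated question (field names are illustrative)."""
    question: str
    flagged: bool
    relevance: float
    grounding: float
    action: str


def summarize(results: List[CheckResult]) -> Dict[str, float]:
    """Share of flagged answers and mean scores across all questions."""
    n = len(results)
    if n == 0:
        return {'count': 0, 'flagged_rate': 0.0,
                'mean_relevance': 0.0, 'mean_grounding': 0.0}
    return {
        'count': n,
        'flagged_rate': sum(r.flagged for r in results) / n,
        'mean_relevance': sum(r.relevance for r in results) / n,
        'mean_grounding': sum(r.grounding for r in results) / n,
    }
```

Inside the loop, one could append `CheckResult(_question, _ind, rel_scor, ground_scor, _guardrail_action)` to a list; `pd.DataFrame([dataclasses.asdict(r) for r in results])` then gives a table view.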



## Summary
---

In this notebook, we demonstrated how to use **Amazon Bedrock Guardrails** as a standalone guardrail for hallucination detection. The **contextual grounding check** flags potential hallucinations by comparing model responses against the retrieved reference passages, and its relevance and grounding scores give a quantitative measure of how well each response is supported by that context and answers the user's question.

- **Grounding Threshold**: This represents the minimum confidence score for a model response to be considered grounded. Responses with scores below this threshold are deemed to contain information not supported by the reference source.

- **Relevance Threshold**: This is the minimum confidence score for a model response to be considered relevant to the user's query. Responses scoring below this threshold are considered off-topic or not addressing the user's question adequately.
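A response is flagged when its score falls below the threshold, the same `score < threshold` comparison used in `hallucination_detection`. With three hypothetical grounding scores, raising the threshold from 0.5 to 0.85 flags one more response, trading fewer missed hallucinations for more blocked answers:

```python
from typing import List


def flag_responses(scores: List[float], threshold: float) -> List[bool]:
    """True where a score is below the threshold, i.e. the response would be flagged."""
    return [score < threshold for score in scores]


# Hypothetical grounding scores for three responses.
example_scores = [0.30, 0.70, 0.90]
for t in (0.5, 0.85):
    print(t, flag_responses(example_scores, t))
# 0.5 [True, False, False]
# 0.85 [True, True, False]
```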

The `ApplyGuardrail` API decouples guardrails from specific foundation models, allowing for more versatile, model-agnostic content moderation. Because it evaluates text rather than a model invocation, the same safety checks can be applied consistently across different AI models, including custom or third-party foundation models.
