# SageMaker RAG retrieval and generation with SageMaker Inference and Bedrock Guardrails

This lab demonstrates how to enhance a Retrieval-Augmented Generation (RAG) pipeline by integrating Amazon SageMaker Inference with Amazon Bedrock Guardrails. We walk through querying an OpenSearch vector knowledge base, using SageMaker for model inference, applying Guardrails to control the generation of responses, and filtering results with metadata to ensure compliance and quality. The lab reuses the PubMed medical theme from the OpenSearch RAG lab: it queries the OpenSearch vector database created there from the PubMed dataset and shows how guardrails can be used to filter the response.

## Overview
This workshop lab guides you through building a secure and compliant Retrieval-Augmented Generation (RAG) pipeline using Amazon SageMaker for inference, OpenSearch for vector-based retrieval, and Amazon Bedrock Guardrails for response filtering. You will learn to:
- Configure your AWS environment and required libraries.
- Set up and publish Bedrock Guardrails to enforce compliance, safety, and contextual relevance.
- Connect to and query an Amazon OpenSearch vector database populated with PubMed medical data (see lab 03 sagemaker-opensearch-rag for the OpenSearch database used in this lab).
- Integrate SageMaker endpoints for language model inference (see lab 03 sagemaker-opensearch-rag for the SageMaker endpoint details, which are reused here).
- Apply guardrails to restrict inappropriate, unsafe, or non-compliant outputs, including blocking certain topics and anonymizing sensitive information.
- Validate the effectiveness of guardrails using real-world medical and compliance scenarios.

## Pre-requisites
- Completion of previous labs (especially OpenSearch RAG ingestion).
- Access to AWS SageMaker, OpenSearch, and appropriate IAM permissions.
- Installation of Python libraries: opensearch-py, langchain, boto3, requests-aws4auth, certifi.

### Install libraries 
Prepare your environment with all necessary libraries and AWS credentials.
- Install the required Python libraries by running the provided pip install commands for langchain, boto3, opensearch-py, requests-aws4auth, and certifi.
- Import all necessary modules (boto3, json, etc.).
- Confirm that your AWS credentials and permissions are set up, and that you have access to the required SageMaker and OpenSearch resources.

In [None]:
%pip install langchain boto3 -q
%pip install opensearch-py
%pip install requests-aws4auth
%pip install certifi
print("Installs completed.")

In [None]:
import boto3
import json
from botocore.exceptions import ClientError, BotoCoreError

## 1. Set configuration variables
Set up all configuration variables for OpenSearch and SageMaker. Note that `%store -r variable_name` retrieves a variable stored by a previous lab. If the stored values are not available, set them manually:
- `AOS_HOST`: your OpenSearch domain endpoint (note: without `https://`).
- `OPENSEARCH_INDEX_NAME`: the name of your OpenSearch index from the ingestion lab.
- `GENERATION_ENDPOINT_NAME`: your SageMaker LLM endpoint name.

The cells in this section then:
- Initialize your AWS session and retrieve your caller identity ARN and region.
- Initialize the Bedrock clients for your AWS region.
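Because the host value must not include the `https://` scheme, a small normalization step can guard against a pasted full URL. The following is a minimal sketch; the helper name `normalize_aos_host` and the example domain are hypothetical and not part of the lab code.

```python
def normalize_aos_host(endpoint: str) -> str:
    """Return the bare host name from an OpenSearch endpoint string."""
    host = endpoint.strip()
    # Drop the scheme if a full URL was pasted
    for scheme in ("https://", "http://"):
        if host.lower().startswith(scheme):
            host = host[len(scheme):]
    # Drop any trailing slash or path component
    return host.split("/", 1)[0]


# Hypothetical endpoint used only for illustration
example_host = normalize_aos_host("https://search-example-domain.us-east-1.es.amazonaws.com/")
print(example_host)
```

The result can then be assigned to `AOS_HOST` before the configuration cell builds `OPENSEARCH_URL`.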

In [None]:
%store -r OS_DOMAIN_NAME
%store -r AOS_HOST
%store -r OPENSEARCH_INDEX_NAME
%store -r EMBEDDING_MODEL_NAME
%store -r EMBED_ENDPOINT_NAME
%store -r GENERATION_ENDPOINT_NAME

In [None]:
print(f"OS_DOMAIN_NAME:{OS_DOMAIN_NAME}")
print(f"AOS_HOST:{AOS_HOST}")
print(f"OPENSEARCH_INDEX_NAME:{OPENSEARCH_INDEX_NAME}")
print(f"EMBEDDING_MODEL_NAME:{EMBEDDING_MODEL_NAME}")
print(f"EMBED_ENDPOINT_NAME:{EMBED_ENDPOINT_NAME}")
print(f"GENERATION_ENDPOINT_NAME:{GENERATION_ENDPOINT_NAME}")

In [None]:
# Opensearch Configuration
OPENSEARCH_URL = f"https://{AOS_HOST}"
service = "es"  
port = 443 

# Sagemaker configuration
session = boto3.Session()
sts_client = boto3.client('sts')
# Get caller identity
caller_identity = sts_client.get_caller_identity()

# Extract the caller's IAM ARN and account ID from the same response
iam_role_arn = caller_identity["Arn"]
account_id = caller_identity["Account"]
region = session.region_name

print("Session's IAM Role ARN:", iam_role_arn)

# Initialize the Amazon Bedrock client in the region
bedrock = boto3.client('bedrock', region_name=region)
bedrock_client = boto3.client("bedrock-runtime", region_name=region)

## 2. Setup Amazon Bedrock Guardrails

Bedrock Guardrails enable us to define policies that restrict or modify model responses based on compliance, safety, and contextual relevance. In this section you will define policies to control and filter model responses for compliance and safety.

To implement guardrails in the RAG pipeline, we use Amazon Bedrock's API to programmatically define safety and compliance policies. Here's how these key functions work:
- `bedrock.create_guardrail`: Defines policies to filter inappropriate content and enforce compliance in model responses.
- `bedrock.create_guardrail_version`: Publishes a guardrail configuration for deployment.

Implementation Workflow
1. Create the Bedrock guardrail: define policies using `create_guardrail`.
2. Version management: publish with `create_guardrail_version`.
3. Attach the Bedrock guardrail to the inference step.

Together this configuration ensures all RAG inference interactions:
- Block requests for unverified treatments
- Anonymize patient identifiers
- Filter speculative medical claims
- Maintain audit trails for compliance

See AWS Bedrock Guardrails documentation for more details: https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-components.html 

### 2.1 Create a Guardrail
You will first create an Amazon Bedrock guardrail with the following policies:
- Topic-based restrictions: Block non-medical advice, misinformation, and unsupported cure claims.
- Content filtering: Block hate, insults, sexual content, violence, misconduct, and prompt injection. Supported strength levels: NONE|LOW|MEDIUM|HIGH
- Contextual grounding: Ensure answers are relevant to PubMed data.
- Word filtering: Block specific sensitive or misleading terms.
- Sensitive data anonymization: Automatically anonymize PII and sensitive medical data using entity types and regex patterns.

After you execute the code block in this section please note the printed Guardrail ID, ARN, and version for later use.

Note: To perform AWS Bedrock guardrail operations the IAM role will need the managed policy `arn:aws:iam::aws:policy/AmazonBedrockFullAccess`
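Since content filters accept only the strength levels NONE, LOW, MEDIUM, and HIGH, it can help to validate the values before calling the API. The sketch below builds a `filtersConfig` list from a plain dictionary; the helper name `build_content_filters` is hypothetical, and the `PROMPT_ATTACK` entry simply mirrors the input/output strengths used in this lab.

```python
ALLOWED_STRENGTHS = {"NONE", "LOW", "MEDIUM", "HIGH"}


def build_content_filters(strengths: dict) -> list:
    """Build a filtersConfig list from {filter_type: (input_strength, output_strength)}."""
    filters = []
    for filter_type, (input_strength, output_strength) in strengths.items():
        # Reject typos such as "High" or "STRICT" before the API call
        for value in (input_strength, output_strength):
            if value not in ALLOWED_STRENGTHS:
                raise ValueError(f"Unsupported strength {value!r} for {filter_type}")
        filters.append({
            "type": filter_type,
            "inputStrength": input_strength,
            "outputStrength": output_strength,
        })
    return filters


filters_config = build_content_filters({
    "HATE": ("HIGH", "HIGH"),
    "PROMPT_ATTACK": ("HIGH", "NONE"),
})
print(filters_config)
```

The returned list could be passed as `contentPolicyConfig={'filtersConfig': filters_config}` when creating the guardrail.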

In [None]:
import uuid

# Generate a unique client request token
client_request_token = str(uuid.uuid4())

# Create a Guardrail with specific filtering and compliance policies for medical use-case
response = bedrock.create_guardrail(
    name="MedicalContextGuardrails",
    description="Restrict responses to PubMed-based medical content only",
    blockedInputMessaging="This request cannot be processed due to safety protocols.",
    blockedOutputsMessaging="Response blocked per compliance guidelines.",

    # Topic-based restrictions (e.g., denying non-medical advice)
    topicPolicyConfig={
        'topicsConfig': [
            {'name': 'non-medical-advice', 'definition': 'Any recommendations outside medical expertise or context', 'type': 'DENY'},
            {'name': 'misinformation', 'definition': 'Dissemination of inaccurate or unverified medical information', 'type': 'DENY'},
            {'name': 'medical-cure-claims', 'definition': 'Claims of guaranteed or definitive cures for medical conditions without sufficient evidence', 'type': 'DENY'}
        ]
    },

    # Content filtering policies (e.g., blocking harmful or unethical content)
    contentPolicyConfig={
        'filtersConfig': [
            {'type': 'HATE', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'INSULTS', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'SEXUAL', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'VIOLENCE', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'MISCONDUCT', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'PROMPT_ATTACK', 'inputStrength': 'HIGH', 'outputStrength': 'NONE'}
        ]
    },

    # Contextual grounding policies ensuring relevance to PubMed-based embeddings
    contextualGroundingPolicyConfig={
        # Ensure responses are grounded in the embeddings loaded from PubMed articles
        'filtersConfig': [
            {'type': 'GROUNDING', 'threshold': 0.1},
            {'type': 'RELEVANCE', 'threshold': 0.1}
        ]
    },

    # List of restricted words related to sensitive medical topics
    wordPolicyConfig={
        # Example: blocking inappropriate usage of critical medical terms
        'wordsConfig': [
            {'text': "malpractice"}, {'text': "misdiagnosis"}, {'text': "unauthorized treatment"},
            {'text': "experimental drug"}, {'text': "unapproved therapy"}, {'text': "medical fraud"},
            {'text': "cure"}, {'text': "guaranteed cure"}, {'text': "permanent remission"}
        ]
    },

    # Sensitive data anonymization (e.g., patient information)
    sensitiveInformationPolicyConfig={
        # Anonymize identifiable patient information
        'piiEntitiesConfig': [
            {'type': "NAME", "action": "ANONYMIZE"}, {'type': "EMAIL", "action": "ANONYMIZE"},
            {'type': "PHONE", "action": "ANONYMIZE"}, {'type': "US_SOCIAL_SECURITY_NUMBER", "action": "ANONYMIZE"},
            {'type': "ADDRESS", "action": "ANONYMIZE"}, {'type': "CA_HEALTH_NUMBER", "action": "ANONYMIZE"},
            {'type': "PASSWORD", "action": "ANONYMIZE"}, {'type': "IP_ADDRESS", "action": "ANONYMIZE"},
            {'type': "CA_SOCIAL_INSURANCE_NUMBER", "action": "ANONYMIZE"}, {'type': "CREDIT_DEBIT_CARD_NUMBER", "action": "ANONYMIZE"},
            {'type': "AGE", "action": "ANONYMIZE"}, {'type': "US_BANK_ACCOUNT_NUMBER", "action": "ANONYMIZE"}
        ],
        # Example regex patterns for anonymizing sensitive medical data
        'regexesConfig': [
            {
                "name": "medical_procedure_code",
                "description": "Pattern for medical procedure codes",
                "pattern": "\\b[A-Z]{1,5}\\d{1,5}\\b",
                "action": "ANONYMIZE"
            },
            {
                "name": "clinical_trial_id",
                "description": "Pattern for clinical trial identifiers",
                "pattern": "\\bNCT\\d{8}\\b",
                "action": "ANONYMIZE"
            }
        ]
    },

    # Tags for environment tracking
    tags=[
        {"key": "Environment", "value": "Production"},
        {"key": "Department", "value": "Medical"}
    ],
    clientRequestToken=client_request_token
)

# Retrieve and print the Guardrail ID, ARN, and version
guardrail_id = response['guardrailId']
print(f"Guardrail ID: {guardrail_id}")
print(f"Guardrail ARN: {response['guardrailArn']}")
print(f"Version: {response['version']}")
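The two regex patterns in the guardrail can be previewed locally to see what they would match. This is only a sketch using Python's `re` module, which may not behave identically to the regex engine Bedrock uses; the sample sentence is made up for illustration.

```python
import re

# Same patterns as in the guardrail's regexesConfig
PATTERNS = {
    "medical_procedure_code": r"\b[A-Z]{1,5}\d{1,5}\b",
    "clinical_trial_id": r"\bNCT\d{8}\b",
}


def preview_regex_matches(text: str) -> dict:
    """Return the substrings each pattern matches in the text."""
    return {name: re.findall(pattern, text) for name, pattern in PATTERNS.items()}


sample = "Trial NCT12345678 used procedure code CPT99213 and drug B12."
matches = preview_regex_matches(sample)
for name, found in matches.items():
    print(name, found)
```

In this sample the procedure-code pattern also matches `B12`, which suggests the pattern is broad and may anonymize more than procedure codes; tightening it is worth considering for real workloads.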


### 2.2 Create a Published Version of the Guardrail
Now we will publish the created Amazon Bedrock guardrail for use in inference.
We will use the `bedrock.create_guardrail_version` command to publish the guardrail.
Store the returned version identifier for use in later steps.

In [None]:
# First create a published version
version_response = bedrock.create_guardrail_version(
    guardrailIdentifier=guardrail_id,
    description="Production version 1.0"
)
guardrail_version=version_response['version']
guardrail_version

### 2.3 Define function to apply bedrock guardrail at inference
Using the Amazon Bedrock guardrail created above, we will now define a function that applies guardrails to model outputs. The function `apply_output_guardrail` calls Bedrock Guardrails on the generated text and returns the filtered output. You will use this function after generating model responses in later steps.

In [None]:
def apply_output_guardrail(output_text):
    """
    Applies guardrail policies to filter and sanitize the output text from LLM responses.

    This function processes the output text through defined guardrail policies to detect and
    handle sensitive information, ensuring compliance with security and privacy requirements.
    It can mask, anonymize, or block responses containing protected information like
    health insurance IDs, personal identifiers, or other sensitive data.

    Args:
        output_text (str): The raw output text from the LLM response to be processed.

    Returns:
        str: The guardrail-processed text when Bedrock returns modified output
            (masked, anonymized, or replaced by the blocked-output message).
            The original text is returned unchanged if the guardrail makes no
            changes or if the API call fails.

    Note:
        Errors from the Bedrock call are caught and printed as a warning rather
        than raised, so a failed call passes the unfiltered text through.
    
    """
    
    print(f"\nApply bedrock guardrails to the output using {guardrail_id} {guardrail_version}\n")
    
    try:
        # Use only the parameters supported by your boto3 version
        response = bedrock_client.apply_guardrail(
            guardrailIdentifier=guardrail_id,
            guardrailVersion=guardrail_version,
            source='OUTPUT',
            content=[
                {
                    'text': {
                        'text': output_text
                    }
                }
            ]
        )
        
        # Process response based on what fields are available
        if 'outputs' in response and response['outputs']:
            return response['outputs'][0]['text']
        else:
            return output_text
            
    except Exception as e:
        print(f"Warning: Output guardrail application failed: {str(e)}")
        return output_text
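To tell a blocked or masked response apart from one that passed unchanged, the `action` field of the `apply_guardrail` response can be inspected alongside `outputs`. The sketch below works on plain dictionaries shaped like that response; the helper name `summarize_guardrail_response` and the mock responses are hypothetical.

```python
def summarize_guardrail_response(response: dict, original_text: str) -> dict:
    """Report whether the guardrail intervened and which text to show."""
    action = response.get("action", "NONE")
    outputs = response.get("outputs") or []
    # Fall back to the original text when the guardrail returned no output
    text = outputs[0].get("text", original_text) if outputs else original_text
    return {"intervened": action == "GUARDRAIL_INTERVENED", "text": text}


# Mock responses for illustration only (no AWS call is made)
mock_blocked = {
    "action": "GUARDRAIL_INTERVENED",
    "outputs": [{"text": "Response blocked per compliance guidelines."}],
}
mock_passed = {"action": "NONE", "outputs": []}

print(summarize_guardrail_response(mock_blocked, "raw model text"))
print(summarize_guardrail_response(mock_passed, "raw model text"))
```

Returning the intervention flag next to the text makes it easier to log or count guardrail hits during validation.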

## 3. Define SageMaker functions
In this section we set up SageMaker endpoint integration for LLM inference. We use an approach similar to the previous lab for OpenSearch retrieval and SageMaker inference.
- Define a custom ContentHandler for input/output formatting.
- Initialize the SagemakerEndpoint object with your endpoint name, region, and content handler.
- Prepare a prompt template for question-answering using context retrieved from OpenSearch.
- Test the setup by running a sample question and context through the model and printing the response.
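The content handler's job is to serialize the prompt into the JSON body the endpoint expects and to parse the generated text out of the streamed response. The standalone sketch below shows that round trip without calling an endpoint. It assumes the endpoint returns either a dict or a single-element list with a `generated_text` key, which depends on the deployed model container; the function names are hypothetical.

```python
import io
import json


def build_payload(prompt: str, model_kwargs: dict) -> bytes:
    """Serialize the prompt and generation parameters as a JSON request body."""
    return json.dumps({"inputs": prompt, "parameters": model_kwargs}).encode("utf-8")


def parse_generated_text(stream) -> str:
    """Read a streamed JSON response and return the generated text."""
    response_json = json.loads(stream.read().decode("utf-8"))
    # Some containers wrap the result in a list; unwrap the first item
    if isinstance(response_json, list):
        response_json = response_json[0]
    return response_json["generated_text"]


payload = build_payload("Hello", {"max_new_tokens": 250})
fake_stream = io.BytesIO(json.dumps([{"generated_text": "Hi there"}]).encode("utf-8"))
generated = parse_generated_text(fake_stream)
print(json.loads(payload))
print(generated)
```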

In [None]:
import json
from typing import Dict

from langchain.chains.question_answering import load_qa_chain
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.prompts import PromptTemplate

prompt_template = """
        <|begin_of_text|>
        <|start_header_id|>system<|end_header_id|>
        You are an assistant for question-answering tasks. Answer the following question using the provided context. If you don't know the answer, just say "I don't know.".
        <|start_header_id|>user<|end_header_id|>
        Context: {context}
        
        Question: {question}
        <|start_header_id|>assistant<|end_header_id|> 
        Answer:"""

# The template uses both {context} and {question}, so declare both variables
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        #print("Input prompt:", input_str)
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # Parse the endpoint response (a streaming body) into JSON
        response_json = json.loads(output.read().decode("utf-8"))
        #print("Raw response:", response_json)  # Debugging
        # Handle both response shapes: a list of results or a single dict
        if isinstance(response_json, list):
            response_json = response_json[0]
        return response_json["generated_text"]


content_handler = ContentHandler()

sagemaker_llm=SagemakerEndpoint(
        endpoint_name=GENERATION_ENDPOINT_NAME,
        region_name=region,
        model_kwargs={"temperature": 1e-10, "max_new_tokens": 250},
        content_handler=content_handler,
    )


In [None]:
# Updated prompt and context for medical question

prompt = "What is the role of mitochondrial dynamics in programmed cell death in lace plants?"
context_prompt = """
Based on research into Aponogeton madagascariensis (lace plant), programmed cell death (PCD) occurs in the cells at the center of areoles in leaves.
The role of mitochondrial dynamics during this process is being investigated.
"""

input_prompt = PROMPT.format(question=prompt, context=context_prompt)

# Invoke the model using the prompt (invoke replaces the deprecated direct call)
response = sagemaker_llm.invoke(input_prompt)
print(response)


## 4. OpenSearch retrieval


### 4.1 Define the OpenSearch vector database retrieval
Connect to OpenSearch and prepare for vector-based retrieval. The steps are:
- Set up OpenSearch authentication using your AWS credentials and the `AWSV4SignerAuth` class.
- Initialize the OpenSearch client and verify the connection by listing available indices.
- Set up the sentence embedding step, which calls the SageMaker embedding endpoint (`EMBED_ENDPOINT_NAME`) from the previous lab (for example, one hosting sentence-transformers/all-MiniLM-L6-v2).

Note: The IAM role used will need the managed permissions `arn:aws:iam::aws:policy/AmazonOpenSearchServiceFullAccess`

In [None]:
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
import boto3
import json

In [None]:
# Connect to OpenSearch using the IAM Role of this notebook
credentials = boto3.Session().get_credentials()
signerauth = AWSV4SignerAuth(credentials, region, "es")

# Create OpenSearch client
aos_client = OpenSearch(
    hosts=[f"https://{AOS_HOST}"],
    http_auth=signerauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    timeout=60
)
print("Connection details: ")
aos_client

In [None]:
# Test opensearch connection by listing indices
try:
    response = aos_client.indices.get_alias("*")
    print("Indices:", response)
except Exception as e:
    print("Error connecting to OpenSearch:", str(e))

In [None]:
from typing import Any, Dict, List, Optional
from langchain_community.embeddings import SagemakerEndpointEmbeddings
from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

class EmbedContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:
        """
        Transforms the input into bytes that can be consumed by SageMaker endpoint.
        Args:
            inputs: List of input strings.
            model_kwargs: Additional keyword arguments to be passed to the endpoint.
        Returns:
            The transformed bytes input.
        """
        # Example: inference.py expects a JSON string with an "inputs" key:
        input_str = json.dumps({"inputs": inputs, **model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> List[List[float]]:
        """
        Transforms the bytes output from the endpoint into a list of embeddings.
        Args:
            output: The bytes output from SageMaker endpoint.
        Returns:
            The transformed output - list of embeddings
        Note:
            The length of the outer list is the number of input strings.
            The length of the inner lists is the embedding dimension.
        """
        # Example: inference.py returns a JSON list of embeddings,
        # one per input string, which is returned as-is:
        response_json = json.loads(output.read().decode("utf-8"))
        # print(len(response_json))
        return response_json


embed_content_handler = EmbedContentHandler()

In [None]:
def get_embedding(text, embed_endpoint_name, model_kwargs=None):
    """
    Call the SageMaker embedding model to embed the given text.
    Adjust the payload and response parsing according to your model's API.
    """
    embeddings = SagemakerEndpointEmbeddings(
        endpoint_name=embed_endpoint_name,
        region_name=region,
        content_handler=embed_content_handler,
    )

    return embeddings.embed_query(text)

In [None]:
# RAG retrieval function
def retrieve_context(query, k=3):
    query_embedding = get_embedding(query, EMBED_ENDPOINT_NAME)
    
    search_body = {
        "size": k,
        "query": {
            "knn": {
                "context_vector": {
                    "vector": query_embedding,
                    "k": k
                }
            }
        }
    }
    
    response = aos_client.search(
        index=OPENSEARCH_INDEX_NAME,
        body=search_body
    )
    
    return [hit["_source"]['contexts'] for hit in response["hits"]["hits"]]

In [None]:
# Retrieve relevant context
query = "What are the key components of phonological processing that are believed to influence \
reading levels in individuals who have undergone cerebral hemispherectomy procedure?"
contexts = retrieve_context(query)
print("Retrieved context:", contexts)

### 4.2 Retrieve context for input prompt
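Retrieve relevant context for a query using vector search. The helper defined in this section, `get_opensearch_vector_context`, wraps `retrieve_context` and returns only the top-ranked passage.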
- Use the `retrieve_context` function to generate an embedding for the input query and perform a k-NN search in OpenSearch.
- The function returns the top-k most relevant context passages.
- Test the retrieval by running a sample query and printing the results.
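The k-NN request and the response parsing can be checked without a live cluster. The sketch below mirrors the query shape and field names used in `retrieve_context` (`context_vector` and `contexts`); the helper names, the toy three-dimensional vector, and the mock search response are assumptions for illustration.

```python
def build_knn_query(vector, k=3, field="context_vector"):
    """Build an OpenSearch k-NN search body for the given query vector."""
    return {"size": k, "query": {"knn": {field: {"vector": vector, "k": k}}}}


def extract_contexts(search_response, source_field="contexts"):
    """Pull the stored context passages out of a search response."""
    return [hit["_source"][source_field] for hit in search_response["hits"]["hits"]]


# Toy vector and mock response; a real embedding has many more dimensions
body = build_knn_query([0.1, 0.2, 0.3], k=2)
mock_response = {
    "hits": {
        "hits": [
            {"_score": 0.92, "_source": {"contexts": "Passage A"}},
            {"_score": 0.87, "_source": {"contexts": "Passage B"}},
        ]
    }
}
print(body)
print(extract_contexts(mock_response))
```

Hits come back ordered by score, so the first extracted passage is the closest match to the query vector.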

In [None]:
def get_opensearch_vector_context(input_query):
    # Retrieve relevant context passages for the query
    contexts = retrieve_context(input_query)

    for context in contexts:
        print("\nRetrieved context: ", context)

    # Use the top-ranked passage, or None when nothing was retrieved
    return contexts[0] if contexts else None

## 5. Validate inference results with Bedrock guardrails
Now we use the inference and guardrail setup from the previous sections to demonstrate how guardrails block unsafe or non-compliant responses and anonymize or block sensitive information in RAG inference. The workflow is:
- Input a query that should trigger a guardrail (e.g., "Is there a guaranteed cure for Alzheimer's disease based on the latest research?").
- Retrieve relevant context using OpenSearch.
- Format the prompt and run inference through SageMaker.
- Print the initial (unfiltered) response.
- Apply the guardrail function and print the filtered response, which should display the preconfigured guardrail message.
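The workflow above can be sketched as one function that takes the retrieval, generation, and guardrail steps as callables, which makes the flow easy to exercise with stand-ins before wiring in the real services. The function name `answer_with_guardrail`, the simplified template, and the stub functions are hypothetical; in the lab the three callables would correspond to the OpenSearch retrieval helper, the SageMaker LLM, and `apply_output_guardrail`.

```python
PROMPT_TEXT = "Context: {context}\n\nQuestion: {question}\nAnswer:"


def answer_with_guardrail(query, retrieve, generate, guard, template=PROMPT_TEXT):
    """Run retrieve -> prompt -> generate -> guard and keep both responses."""
    context = retrieve(query)
    prompt = template.format(context=context, question=query)
    raw_response = generate(prompt)
    guarded_response = guard(raw_response)
    return {"query": query, "raw": raw_response, "guarded": guarded_response}


# Stand-ins so the flow can run without AWS access
def fake_retrieve(query):
    return "Placeholder context passage."


def fake_generate(prompt):
    return "No, there is no guaranteed cure."


def fake_guard(text):
    # Crude stand-in for the word filter on "cure"
    if "cure" in text.lower():
        return "Response blocked per compliance guidelines."
    return text


result = answer_with_guardrail(
    "Is there a guaranteed cure for Alzheimer's disease?",
    fake_retrieve, fake_generate, fake_guard,
)
print("Raw:", result["raw"])
print("Guarded:", result["guarded"])
```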

### 5.1 Restrict response based on blocked content/topic

In this example you'll test the guardrail's ability to block responses related to cures for diseases without sufficient evidence.

When prompted with "Is there a cure for HIV based on the latest research?", the base model responds with "No", which is correct. In this scenario, however, the model shouldn't provide any affirmative answers related to cures, per compliance guidelines.

Therefore the guardrail should block the response and respond with "Response blocked per compliance guidelines.". You could additionally apply this guardrail on the input, blocking the request from ever making it to the generation model to begin with.
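Applying the guardrail to the input could look like the following sketch, which calls `apply_guardrail` with `source="INPUT"` before any retrieval or generation happens. The function name `apply_input_guardrail` is hypothetical, and the client is passed in as a parameter so the example can run against a stub; in the lab it would be the `bedrock-runtime` client together with the guardrail ID and version created earlier.

```python
def apply_input_guardrail(client, guardrail_id, guardrail_version, user_text):
    """Check the user's query against the guardrail before calling the model."""
    response = client.apply_guardrail(
        guardrailIdentifier=guardrail_id,
        guardrailVersion=guardrail_version,
        source="INPUT",
        content=[{"text": {"text": user_text}}],
    )
    intervened = response.get("action") == "GUARDRAIL_INTERVENED"
    outputs = response.get("outputs") or []
    # When the guardrail intervenes, outputs carries the replacement text
    text = outputs[0]["text"] if (intervened and outputs) else user_text
    return intervened, text


class StubGuardrailClient:
    """Stand-in client for illustration; flags any query mentioning 'cure'."""

    def apply_guardrail(self, **kwargs):
        query_text = kwargs["content"][0]["text"]["text"]
        if "cure" in query_text.lower():
            return {
                "action": "GUARDRAIL_INTERVENED",
                "outputs": [{"text": "This request cannot be processed due to safety protocols."}],
            }
        return {"action": "NONE", "outputs": []}


stub = StubGuardrailClient()
print(apply_input_guardrail(stub, "example-id", "1", "Is there a cure for HIV?"))
print(apply_input_guardrail(stub, "example-id", "1", "What is phonological processing?"))
```

If the first element is `True`, the pipeline can return the message directly and skip the retrieval and generation calls.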

In [None]:
# Define the query that will be sent to the model
query = "Is there a cure for HIV based on the latest research?"

# Retrieve relevant context from OpenSearch based on the query
context_prompt = get_opensearch_vector_context(query)

input_prompt = PROMPT.format(question=query, context=context_prompt)

# Invoke the model using the prompt (invoke replaces the deprecated direct call)
raw_response = sagemaker_llm.invoke(input_prompt)
print("\n\n Initial response without guardrails: ", raw_response)

guardrail_response = apply_output_guardrail(raw_response)

# Print the user's query
print("\n\n User's query: ", query)  # Use the variable directly, not a dictionary

# Print the generated answer from the model based on the query and context
print("\n\n Answer with guardrails: ", guardrail_response)


### 5.2 Restrict response based on PII data

This example shows a simple redaction of content coming back from the generation model.
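One way to sanity-check the redaction is to scan the guarded response for identifiers that should no longer appear. The sketch below uses simple local regexes for emails, phone numbers, and clinical trial IDs. These patterns are rough assumptions for illustration and are not the detection logic Bedrock uses; the sample strings, including the placeholder style in the redacted one, are made up.

```python
import re

EMAIL_RE = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
PHONE_RE = re.compile(r"\+?\d[\d\s().-]{7,}\d")
TRIAL_RE = re.compile(r"\bNCT\d{8}\b")


def find_residual_identifiers(text: str) -> dict:
    """List identifier-like substrings still present in the text."""
    return {
        "email": EMAIL_RE.findall(text),
        "phone": PHONE_RE.findall(text),
        "clinical_trial_id": TRIAL_RE.findall(text),
    }


# Made-up examples of an unredacted and a redacted response
raw_text = "Email jane.doe@example.org, phone +1 (555) 010-9999, trial NCT12345678."
redacted_text = "Email {EMAIL}, phone {PHONE}, trial {clinical_trial_id}."

raw_hits = find_residual_identifiers(raw_text)
redacted_hits = find_residual_identifiers(redacted_text)
print(raw_hits)
print(redacted_hits)
```

An empty result for every key on the guarded response is a quick signal that the anonymization took effect, although it does not replace a proper compliance review.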

In [None]:
# Define the query that will be sent to the model
query = "Can you provide the contact information, including the phone number and email address, for Dr. Vivek Murthy, who led the clinical trial NCT12345678?"

# Retrieve relevant context from OpenSearch based on the query
context_prompt = get_opensearch_vector_context(query)

input_prompt = PROMPT.format(question=query, context=context_prompt)

# Invoke the model using the prompt (invoke replaces the deprecated direct call)
raw_response = sagemaker_llm.invoke(input_prompt)
print("\n\nInitial response without guardrails: ", raw_response)

guardrail_response = apply_output_guardrail(raw_response)

# Print the user's query
print("\n\nUser's query: ", query) # Use the variable directly, not a dictionary

# Print the generated answer from the model based on the query and context, with applying the guardrails
print("\n\nAnswer with guardrails: ", guardrail_response)

## Summary
By the end of this lab, you will have:
- Configured a secure, compliant RAG pipeline using AWS services.
- Learned how to enforce strict guardrails for safety, compliance, and relevance in medical AI applications.
- Practiced integrating vector retrieval, LLM inference, and response filtering in a real-world workflow.

`Next Steps`: Experiment with your own queries and adjust guardrail policies to fit other compliance or safety requirements in your domain.