# SageMaker RAG retrieval and generation with SageMaker Inference and Bedrock Guardrails

This lab demonstrates how to enhance a Retrieval-Augmented Generation (RAG) pipeline by integrating Amazon SageMaker Inference with Amazon Bedrock Guardrails. We walk through querying an OpenSearch vector knowledge base, using SageMaker for model inference, applying Guardrails to control the generated responses, and filtering results with metadata to ensure compliance and quality. We reuse the PubMed medical dataset from the OpenSearch RAG lab: we query the OpenSearch vector database created there and show how guardrails can filter the model's response.

## Pre-requisites

- Before proceeding with this notebook, you should complete all of the previous labs.
- Required Python libraries: opensearch-py, langchain, langchain-community, boto3, requests-aws4auth, transformers (with a PyTorch or TensorFlow backend)
- Access to Amazon SageMaker and OpenSearch
- Appropriate IAM roles and permissions
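Before running the notebook, it can help to confirm that the libraries above are importable. The sketch below is a hypothetical helper, not part of the lab: it uses `importlib.util.find_spec` to report missing modules without importing them. Note that it checks import names, which can differ from pip package names (for example, `opensearch-py` is imported as `opensearchpy`).

```python
import importlib.util

# Import names (not pip names) of the packages this notebook imports.
# For example, the pip package "opensearch-py" is imported as "opensearchpy".
REQUIRED_MODULES = ["opensearchpy", "langchain", "langchain_community", "boto3",
                    "requests_aws4auth", "transformers"]


def missing_modules(modules):
    """Return the names in `modules` that cannot be found in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


print("Missing modules:", missing_modules(REQUIRED_MODULES) or "none")
```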

## Install libraries 

In [None]:
%pip install langchain langchain-community boto3 -q
%pip install opensearch-py
%pip install requests-aws4auth
%pip install certifi
%pip install transformers
print("Installs completed.")

In [1]:
import boto3
import json
from botocore.exceptions import ClientError, BotoCoreError

## 1. Set configuration variables

In [None]:
# OpenSearch configuration
aos_host = "<ENTER_VALUE>"  # OpenSearch domain endpoint (IPv4) from the previous lab, without the https:// prefix
OPENSEARCH_URL = f"https://{aos_host}"
service = "es"  # SigV4 service name for Amazon OpenSearch Service domains
port = 443  # HTTPS port of the OpenSearch domain endpoint
index_name = "<ENTER_VALUE>"  # OpenSearch index used for ingestion in the previous lab

# SageMaker configuration
SAGEMAKER_LLM_ENDPOINT_NAME = "<ENTER_VALUE>"  # SageMaker LLM endpoint created in the previous lab
session = boto3.Session()
sts_client = session.client('sts')

# Get the caller identity once and reuse it
caller_identity = sts_client.get_caller_identity()
iam_role_arn = caller_identity["Arn"]
account_id = caller_identity["Account"]
region = session.region_name

print("Session's IAM Role ARN:", iam_role_arn)

# Bedrock control-plane client (create/manage guardrails) and runtime client (apply guardrails)
bedrock = session.client('bedrock', region_name=region)
bedrock_client = session.client("bedrock-runtime", region_name=region)

Session's IAM Role ARN: arn:aws:sts::329599663853:assumed-role/AmazonSageMaker-ExecutionRole-20250203T171097/SageMaker
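The `aos_host` value must be a bare host name, without `https://`. A small hypothetical helper like the one below can normalize an endpoint pasted from the console; it is a sketch using only `urllib.parse`, not something the OpenSearch client requires.

```python
from urllib.parse import urlparse


def normalize_host(endpoint: str) -> str:
    """Hypothetical helper: reduce an endpoint URL or host to a bare, lowercase host name.

    Strips an optional scheme, port, path, and trailing slash.
    """
    endpoint = endpoint.strip()
    if "://" not in endpoint:
        endpoint = "https://" + endpoint  # urlparse needs a scheme to find the host
    host = urlparse(endpoint).hostname
    if not host:
        raise ValueError(f"Could not extract a host name from {endpoint!r}")
    return host


# Example usage:
# aos_host = normalize_host(aos_host)
```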


## 2. Set up Amazon Bedrock Guardrails

### 2.1 Create a Guardrail

Bedrock Guardrails enable us to define policies that restrict or modify model responses based on compliance, safety, and contextual relevance. 

See the Amazon Bedrock Guardrails documentation for more details: https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-components.html 

Note: Attach the AWS managed policy `arn:aws:iam::aws:policy/AmazonBedrockFullAccess` to the IAM role used by this notebook.

In [4]:
import uuid

# Generate a unique client request token
client_request_token = str(uuid.uuid4())

# Create a Guardrail with specific filtering and compliance policies for medical use-case
response = bedrock.create_guardrail(
    name="MedicalContextGuardrails",
    description="Restrict responses to PubMed-based medical content only",
    blockedInputMessaging="This request cannot be processed due to safety protocols.",
    blockedOutputsMessaging="Response modified per compliance guidelines.",

    # Topic-based restrictions (e.g., denying non-medical advice)
    topicPolicyConfig={
        'topicsConfig': [
            {'name': 'non-medical-advice', 'definition': 'Any recommendations outside medical expertise or context', 'type': 'DENY'},
            {'name': 'misinformation', 'definition': 'Dissemination of inaccurate or unverified medical information', 'type': 'DENY'},
            {'name': 'medical-cure-claims', 'definition': 'Claims of guaranteed or definitive cures for medical conditions without sufficient evidence', 'type': 'DENY'}
        ]
    },

    # Content filtering policies (e.g., blocking harmful or unethical content)
    contentPolicyConfig={
        'filtersConfig': [
            {'type': 'HATE', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'INSULTS', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'SEXUAL', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'VIOLENCE', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'MISCONDUCT', 'inputStrength': 'HIGH', 'outputStrength': 'HIGH'},
            {'type': 'PROMPT_ATTACK', 'inputStrength': 'HIGH', 'outputStrength': 'NONE'}
        ]
    },

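    # Contextual grounding policies: check that responses are grounded in, and relevant to, the retrieved PubMed text
    contextualGroundingPolicyConfig={
        # Evaluated only when ApplyGuardrail receives content tagged as grounding source and query.
        # Responses scoring below the threshold are blocked, so 0.1 is a very permissive setting.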
        'filtersConfig': [
            {'type': 'GROUNDING', 'threshold': 0.1},
            {'type': 'RELEVANCE', 'threshold': 0.1}
        ]
    },

    # List of restricted words related to sensitive medical topics
    wordPolicyConfig={
        # Example: blocking inappropriate usage of critical medical terms
        'wordsConfig': [
            {'text': "malpractice"}, {'text': "misdiagnosis"}, {'text': "unauthorized treatment"},
            {'text': "experimental drug"}, {'text': "unapproved therapy"}, {'text': "medical fraud"},
            {'text': "cure"}, {'text': "guaranteed cure"}, {'text': "permanent remission"}
        ]
    },

    # Sensitive data anonymization (e.g., patient information)
    sensitiveInformationPolicyConfig={
        # Anonymize identifiable patient information
        'piiEntitiesConfig': [
            {'type': "NAME", "action": "ANONYMIZE"}, {'type': "EMAIL", "action": "ANONYMIZE"},
            {'type': "PHONE", "action": "ANONYMIZE"}, {'type': "US_SOCIAL_SECURITY_NUMBER", "action": "ANONYMIZE"},
            {'type': "ADDRESS", "action": "ANONYMIZE"}, {'type': "CA_HEALTH_NUMBER", "action": "ANONYMIZE"},
            {'type': "PASSWORD", "action": "ANONYMIZE"}, {'type': "IP_ADDRESS", "action": "ANONYMIZE"},
            {'type': "CA_SOCIAL_INSURANCE_NUMBER", "action": "ANONYMIZE"}, {'type': "CREDIT_DEBIT_CARD_NUMBER", "action": "ANONYMIZE"},
            {'type': "AGE", "action": "ANONYMIZE"}, {'type': "US_BANK_ACCOUNT_NUMBER", "action": "ANONYMIZE"}
        ],
        # Example regex patterns for anonymizing sensitive medical data
        'regexesConfig': [
            {
                "name": "medical_procedure_code",
                "description": "Pattern for medical procedure codes",
                "pattern": "\\b[A-Z]{1,5}\\d{1,5}\\b",
                "action": "ANONYMIZE"
            },
            {
                "name": "clinical_trial_id",
                "description": "Pattern for clinical trial identifiers",
                "pattern": "\\bNCT\\d{8}\\b",
                "action": "ANONYMIZE"
            }
        ]
    },

    # Tags for environment tracking
    tags=[
        {"key": "Environment", "value": "Production"},
        {"key": "Department", "value": "Medical"}
    ],
    clientRequestToken=client_request_token
)

# Retrieve and print the Guardrail ID, ARN, and version
guardrail_id = response['guardrailId']
print(f"Guardrail ID: {guardrail_id}")
print(f"Guardrail ARN: {response['guardrailArn']}")
print(f"Version: {response['version']}")


Guardrail ID: cb7wc871en5w
Guardrail ARN: arn:aws:bedrock:us-west-2:329599663853:guardrail/cb7wc871en5w
Version: DRAFT
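Based on our reading of the ApplyGuardrail API, the contextual grounding filters configured above are evaluated only when the content passed to `apply_guardrail` includes text blocks tagged with the `grounding_source` and `query` qualifiers; the plain output check defined later sends only the model output. The sketch below builds such a content list. The helper name is hypothetical; the qualifier values are the ones documented for `ApplyGuardrail`.

```python
def build_grounded_content(grounding_source: str, query: str, model_output: str) -> list:
    """Hypothetical helper: build the `content` list for ApplyGuardrail with grounding qualifiers.

    Each text block is tagged so the guardrail knows whether it is the reference
    text, the user's question, or the model output to be checked.
    """
    return [
        {"text": {"text": grounding_source, "qualifiers": ["grounding_source"]}},
        {"text": {"text": query, "qualifiers": ["query"]}},
        {"text": {"text": model_output, "qualifiers": ["guard_content"]}},
    ]


# Example usage, once the guardrail version is published in step 2.2:
# response = bedrock_client.apply_guardrail(
#     guardrailIdentifier=guardrail_id,
#     guardrailVersion=guardrail_version,
#     source="OUTPUT",
#     content=build_grounded_content(context_prompt, query, raw_response),
# )
```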


### 2.2 Create a Published Version of the Guardrail

In [5]:
# Publish an immutable, numbered version of the guardrail from its DRAFT
version_response = bedrock.create_guardrail_version(
    guardrailIdentifier=guardrail_id,
    description="Production version 1.0"
)
guardrail_version = version_response['version']
guardrail_version

'1'
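Creating or versioning a guardrail may return before the guardrail has finished processing. A minimal polling sketch, assuming you check the `status` field returned by `bedrock.get_guardrail`: the helper takes any zero-argument callable so it can be tested without AWS, and the timeout and interval defaults are arbitrary.

```python
import time


def wait_for_guardrail(fetch_status, timeout_s=60.0, interval_s=2.0):
    """Poll `fetch_status()` until it returns "READY"; raise on "FAILED" or on timeout.

    `fetch_status` is any zero-argument callable returning the guardrail status string.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        status = fetch_status()
        if status == "READY":
            return status
        if status == "FAILED":
            raise RuntimeError("Guardrail entered the FAILED state")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Guardrail not READY after {timeout_s}s (last status: {status})")
        time.sleep(interval_s)


# Example usage with the Bedrock control-plane client:
# wait_for_guardrail(lambda: bedrock.get_guardrail(
#     guardrailIdentifier=guardrail_id, guardrailVersion=guardrail_version)["status"])
```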

### 2.3 Define a function to apply the Bedrock guardrail at inference

In [8]:
def apply_output_guardrail(output_text):
    """Apply the published guardrail to model output after generation.

    Returns the guardrail's replacement text if it intervened, otherwise the original text.
    """
    print("Applying Bedrock guardrail to the output using", guardrail_id, guardrail_version)
    try:
        response = bedrock_client.apply_guardrail(
            guardrailIdentifier=guardrail_id,
            guardrailVersion=guardrail_version,
            source='OUTPUT',
            content=[{'text': {'text': output_text}}]
        )

        # 'action' is 'GUARDRAIL_INTERVENED' when a policy blocked or masked content;
        # 'outputs' then holds the blocked message or the anonymized text
        if response.get('action') == 'GUARDRAIL_INTERVENED' and response.get('outputs'):
            return response['outputs'][0]['text']
        return output_text

    except (ClientError, BotoCoreError) as e:
        # Fails open: the unfiltered text is returned; consider failing closed in production
        print(f"Warning: Output guardrail application failed: {e}")
        return output_text
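`apply_output_guardrail` returns only text, so it does not say which policy fired. The `ApplyGuardrail` response also carries an `assessments` list; the sketch below flattens it into `(policy, detail, action)` tuples. The key names reflect our reading of the `ApplyGuardrail` response shape and should be verified against your boto3 version; missing keys are simply skipped.

```python
def summarize_assessments(response: dict) -> list:
    """List (policy, detail, action) tuples for each policy finding in an ApplyGuardrail response."""
    findings = []
    for assessment in response.get("assessments", []):
        for topic in assessment.get("topicPolicy", {}).get("topics", []):
            findings.append(("topic", topic.get("name"), topic.get("action")))
        for flt in assessment.get("contentPolicy", {}).get("filters", []):
            findings.append(("content", flt.get("type"), flt.get("action")))
        for word in assessment.get("wordPolicy", {}).get("customWords", []):
            findings.append(("word", word.get("match"), word.get("action")))
        sensitive = assessment.get("sensitiveInformationPolicy", {})
        for entity in sensitive.get("piiEntities", []):
            findings.append(("pii", entity.get("type"), entity.get("action")))
        for regex in sensitive.get("regexes", []):
            findings.append(("regex", regex.get("name"), regex.get("action")))
        for flt in assessment.get("contextualGroundingPolicy", {}).get("filters", []):
            findings.append(("grounding", flt.get("type"), flt.get("action")))
    return findings


# Example usage (inside or after a call to bedrock_client.apply_guardrail):
# for finding in summarize_assessments(response):
#     print(finding)
```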

## 3. Define SageMaker functions
We use the same approach for OpenSearch retrieval and SageMaker inference as in the previous lab.

In [11]:
import json
from typing import Dict

from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
from langchain_core.prompts import PromptTemplate

prompt_template = """Answer the question using the provided context.
Question: {question}
Context: {context}
Answer:"""

# Both template placeholders are declared as input variables
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["question", "context"]
)


class ContentHandler(LLMContentHandler):
    """Serializes prompts for, and parses responses from, the SageMaker LLM endpoint."""
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        print("Input prompt:", input_str)
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # 'output' is the endpoint's streaming response body
        response_json = json.loads(output.read().decode("utf-8"))
        print("Raw response:", response_json)  # Debugging
        # Some containers return a single object, others a list of generations
        if isinstance(response_json, list):
            response_json = response_json[0]
        return response_json["generated_text"]


content_handler = ContentHandler()

sagemaker_llm = SagemakerEndpoint(
    endpoint_name=SAGEMAKER_LLM_ENDPOINT_NAME,
    region_name=region,
    model_kwargs={"temperature": 1e-10},  # near-zero temperature for near-deterministic output
    content_handler=content_handler,
)
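Under the hood, `SagemakerEndpoint` calls the SageMaker runtime with the JSON body produced by `ContentHandler`. The sketch below shows the same request without LangChain, using `invoke_endpoint` from a `boto3.client("sagemaker-runtime")`. The helper name is hypothetical; the client is passed in so the function can be exercised with a stub.

```python
import io
import json


def invoke_llm_endpoint(runtime_client, endpoint_name: str, prompt: str, parameters: dict) -> str:
    """Hypothetical helper: call a SageMaker LLM endpoint directly and return its generated text.

    `runtime_client` is expected to be boto3.client("sagemaker-runtime").
    """
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps({"inputs": prompt, "parameters": parameters}).encode("utf-8"),
    )
    data = json.loads(response["Body"].read().decode("utf-8"))
    # Some containers return a single object, others a list of generations
    if isinstance(data, list):
        data = data[0]
    return data["generated_text"]


# Example usage (requires AWS credentials and the deployed endpoint):
# runtime = boto3.client("sagemaker-runtime", region_name=region)
# print(invoke_llm_endpoint(runtime, SAGEMAKER_LLM_ENDPOINT_NAME, input_prompt, {"temperature": 1e-10}))
```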


In [12]:
# Sample medical question with a hand-written context to test the endpoint (no retrieval yet)

prompt = "What is the role of mitochondrial dynamics in programmed cell death in lace plants?"
context_prompt = """
Based on research into Aponogeton madagascariensis (lace plant), programmed cell death (PCD) occurs in the cells at the center of areoles in leaves.
The role of mitochondrial dynamics during this process is being investigated.
"""

input_prompt = PROMPT.format(question=prompt, context=context_prompt)

# Invoke the model using the prompt
response = sagemaker_llm.invoke(input_prompt)
print(response)


Input prompt: {"inputs": "Answer the question using the provided context.\nQuestion: What is the role of mitochondrial dynamics in programmed cell death in lace plants?\nContext: \nBased on research into Aponogeton madagascariensis (lace plant), programmed cell death (PCD) occurs in the cells at the center of areoles in leaves.\nThe role of mitochondrial dynamics during this process is being investigated.\n\nAnswer:", "parameters": {"temperature": 1e-10}}
Raw response: {'generated_text': ' The role of mitochondrial dynamics in programmed cell death in lace plants is currently under investigation. Research is focused on understanding how changes in mitochondrial shape and function contribute'}
 The role of mitochondrial dynamics in programmed cell death in lace plants is currently under investigation. Research is focused on understanding how changes in mitochondrial shape and function contribute


## 4. OpenSearch retrieval


### 4.1 Define the OpenSearch vector database retrieval
Note: The IAM role needs the AWS managed policy `arn:aws:iam::aws:policy/AmazonOpenSearchServiceFullAccess`.

In [13]:
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
import boto3
from transformers import pipeline
import json



In [14]:
# OpenSearch Authentication setup
credentials = boto3.Session().get_credentials()
awsauth = AWS4Auth(
    credentials.access_key,
    credentials.secret_key,
    region,
    service,
    session_token=credentials.token
)

# Initialize OpenSearch client
opensearch_client = OpenSearch(
    hosts=[{'host': aos_host, 'port': port}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection
)


In [15]:
# Test the OpenSearch connection by listing indices and their aliases
try:
    indices = opensearch_client.indices.get_alias(index="*")
    print("Indices:", indices)
except Exception as e:
    print("Error connecting to OpenSearch:", str(e))

Indices: {'.plugins-ml-model-group': {'aliases': {}}, 'gtm-rag-index': {'aliases': {}}, '.plugins-flow-framework-state': {'aliases': {}}, '.ql-datasources': {'aliases': {}}, '.plugins-ml-agent': {'aliases': {}}, '.plugins-flow-framework-templates': {'aliases': {}}, '.plugins-ml-task': {'aliases': {}}, '.kibana_1': {'aliases': {'.kibana': {}}}, '.opendistro_security': {'aliases': {}}, '.plugins-ml-config': {'aliases': {}}, '.opensearch-observability': {'aliases': {}}, '.plugins-ml-model': {'aliases': {}}, '.opensearch-sap-log-types-config': {'aliases': {}}, '.plugins-flow-framework-config': {'aliases': {}}}
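The listing above mixes system indices (names starting with `.`) with the RAG index. A small hypothetical helper can filter them out so you can confirm that `index_name` matches the index created in the previous lab.

```python
def user_indices(alias_response: dict) -> list:
    """Return sorted index names from an indices.get_alias() response, skipping system indices."""
    return sorted(name for name in alias_response if not name.startswith("."))


# Example usage:
# names = user_indices(opensearch_client.indices.get_alias(index="*"))
# print(names, "| index_name found:", index_name in names)
```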


In [16]:
# Initialize the embedding model (must be the same model used at ingestion in the previous lab)
embedder = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")

Device set to use cpu
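The `feature-extraction` pipeline returns one vector per token, and `retrieve_context` below uses `embedder(query)[0][0]`, the first token's vector. The sentence-transformers configuration of all-MiniLM-L6-v2 instead uses mean pooling followed by L2 normalization. Query vectors must be computed the same way as the indexed document vectors, so switch only if the index was built with mean pooling. A pure-Python sketch of that pooling for a single, unpadded input:

```python
import math


def mean_pool(token_vectors):
    """Average a list of per-token vectors into one sentence vector."""
    n = len(token_vectors)
    return [sum(column) / n for column in zip(*token_vectors)]


def l2_normalize(vector):
    """Scale a vector to unit length; an all-zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm else list(vector)


# Example usage with the pipeline above (only if documents were embedded the same way):
# sentence_vector = l2_normalize(mean_pool(embedder(query)[0]))
```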


In [17]:
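# RAG retrieval function: embed the query and run a k-NN search on the "vector" field
def retrieve_context(query, k=3):
    # The feature-extraction pipeline returns one vector per token; [0][0] selects the
    # first token's vector. This must match how documents were embedded at ingestion.
    query_embedding = embedder(query)[0][0]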
    
    search_body = {
        "size": k,
        "query": {
            "knn": {
                "vector": {
                    "vector": query_embedding,
                    "k": k
                }
            }
        }
    }
    
    response = opensearch_client.search(
        index=index_name,
        body=search_body
    )
    
    return [hit["_source"]["text"] for hit in response["hits"]["hits"]]
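The introduction mentions filtering results with metadata, but `retrieve_context` runs a plain k-NN query. A minimal sketch of a query body with an optional `filter` inside the `knn` clause is shown below. It assumes the vector field is named `vector` (as above) and that your OpenSearch version and k-NN engine support filtered k-NN search (Lucene and Faiss do in recent releases); any metadata field in the filter is hypothetical and must exist in your index mapping.

```python
def build_knn_query(query_vector, k=3, metadata_filter=None, vector_field="vector"):
    """Build an OpenSearch k-NN search body, optionally restricted by a metadata filter."""
    knn_clause = {"vector": list(query_vector), "k": k}
    if metadata_filter is not None:
        # Filter applied during the k-NN search (engine/version support required)
        knn_clause["filter"] = metadata_filter
    return {"size": k, "query": {"knn": {vector_field: knn_clause}}}


# Example usage with a hypothetical "source" keyword field:
# body = build_knn_query(embedder(query)[0][0], k=3, metadata_filter={"term": {"source": "pubmed"}})
# hits = opensearch_client.search(index=index_name, body=body)["hits"]["hits"]
```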

In [18]:
# Retrieve relevant context
query = (" what are the key components of phonological processing that are believed to influence "
         "reading levels in individuals who have undergone cerebral hemispherectomy procedure?")
contexts = retrieve_context(query)
print("Retrieved context:", contexts)

Retrieved context: ['ID: 25819796 - Cerebral hemispherectomy, a surgical procedure undergone to control intractable seizures, is becoming a standard procedure with more cases identified and treated early in life [33]. While the effect of the dominant hemisphere resection on spoken language has been extensively researched, little is known about reading abilities in individuals after left-sided resection. Left-lateralized phonological abilities are the key components of reading, i.e., grapheme-phoneme conversion skills [1]. These skills are critical for the acquisition of word-specific orthographic knowledge and have been shown to predict reading levels in average readers as well as in readers with mild cognitive disability [26]. Furthermore, impaired phonological processing has been implicated as the cognitive basis in struggling readers. Here, we explored the reading skills in participants who have undergone left cerebral hemispherectomy.', 'ID: 25819796 - Seven individuals who have un

### 4.2 Retrieve context for input prompt

In [19]:
def get_opensearch_vector_context(input_query):
    """Return the top retrieved chunk for input_query, or an empty string if nothing is found."""
    contexts = retrieve_context(input_query)
    print("Retrieved context: ", contexts)

    # Use only the highest-ranked chunk as the prompt context
    return contexts[0] if contexts else ""
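`retrieve_context` returns up to `k=3` chunks, but the function above keeps only the first. If more context helps, a hypothetical helper can concatenate chunks in rank order under a size budget. The budget here counts characters, not model tokens, so the 4000 default is only a rough placeholder.

```python
def join_contexts(contexts, max_chars=4000):
    """Concatenate retrieved chunks, highest-ranked first, without exceeding `max_chars`."""
    selected, used = [], 0
    for chunk in contexts:
        extra = len(chunk) + (2 if selected else 0)  # 2 chars for the "\n\n" separator
        if used + extra > max_chars:
            break
        selected.append(chunk)
        used += extra
    return "\n\n".join(selected)


# Example usage:
# context_prompt = join_contexts(retrieve_context(query))
```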

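## 5. Validate inference results with Bedrock guardrails

### 5.1 Restrict response based on blocked content/topic
Let's ask the model whether there is a guaranteed cure for Alzheimer's disease. When we created the guardrail, we denied the `medical-cure-claims` topic and added "cure" and "guaranteed cure" to the blocked word list. If the guardrail intervenes on the model's output, Bedrock should return the preconfigured `blockedOutputsMessaging` text, "Response modified per compliance guidelines." The `blockedInputMessaging` text, "This request cannot be processed due to safety protocols.", is returned only when the guardrail is applied to the input (`source='INPUT'`).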

In [20]:
# Define the query that will be sent to the model
query = "Is there a guaranteed cure for Alzheimer's disease based on the latest research?"

# Retrieve relevant context from OpenSearch based on the query
context_prompt = get_opensearch_vector_context(query)

input_prompt = PROMPT.format(question=query, context=context_prompt)

# Invoke the model using the prompt
raw_response = sagemaker_llm.invoke(input_prompt)
print("\n Initial response without guardrails: ", raw_response)

# Apply the output guardrail to the raw model response
guardrail_response = apply_output_guardrail(raw_response)

# Print the user's query
print("\n User's query: ", query)

# Print the model's answer after the guardrail has been applied
print("\n Answer with guardrails: ", guardrail_response)


Retrieved context:  ["ID: 9444542 - To investigate whether the presence of hippocampal atrophy (HCA) on MRI in Alzheimer's disease (AD) leads to a more rapid decline in cognitive function. To investigate whether cognitively unimpaired controls and depressed subjects with HCA are at higher risk than those without HCA of developing dementia.", 'ID: 18378554 - Ambulatory residents meeting DSM-IV criteria for dementia (N = 181) were studied.', 'ID: 9444542 - No significant differences in rate of cognitive decline, mortality or progression to dementia were found between subjects with or without HCA.']
Input prompt: {"inputs": "Answer the question using the provided context.\nQuestion: Is there a guaranteed cure for Alzheimer's disease based on the latest research?\nContext: ID: 9444542 - To investigate whether the presence of hippocampal atrophy (HCA) on MRI in Alzheimer's disease (AD) leads to a more rapid decline in cognitive function. To investigate whether cognitively unimpaired control
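The cell above screens only the model's output. To stop a disallowed request before the model is called, and to get the `blockedInputMessaging` text back, the query itself can be screened with `source='INPUT'`. A minimal sketch: the helper name is hypothetical, and the client is passed in (expected to be the `bedrock-runtime` client) so it can be tested with a stub.

```python
def apply_input_guardrail(client, guardrail_id, guardrail_version, user_query):
    """Screen a user query with the guardrail before it reaches the model.

    Returns (intervened, text): text is the guardrail's output if it intervened
    (a blocked message or anonymized query), otherwise the original query.
    """
    response = client.apply_guardrail(
        guardrailIdentifier=guardrail_id,
        guardrailVersion=guardrail_version,
        source="INPUT",
        content=[{"text": {"text": user_query}}],
    )
    if response.get("action") == "GUARDRAIL_INTERVENED" and response.get("outputs"):
        return True, response["outputs"][0]["text"]
    return False, user_query


# Example usage:
# intervened, screened = apply_input_guardrail(bedrock_client, guardrail_id, guardrail_version, query)
# print(intervened, screened)
```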

### 5.2 Restrict response based on PII data
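Before running the PII query below, you can preview locally which strings the guardrail's custom regex patterns would match. This sketch copies the two patterns from `regexesConfig` and runs them with Python's `re`; the guardrail's own regex engine may differ in edge cases, so treat this as a preview only. Note how broad `medical_procedure_code` is: it also matches tokens such as `COVID19`.

```python
import re

# Patterns copied from the guardrail's regexesConfig
GUARDRAIL_REGEXES = {
    "medical_procedure_code": r"\b[A-Z]{1,5}\d{1,5}\b",
    "clinical_trial_id": r"\bNCT\d{8}\b",
}


def preview_regex_matches(text, patterns=GUARDRAIL_REGEXES):
    """Return {pattern_name: [matches]} for each pattern that matches `text`."""
    found = {}
    for name, pattern in patterns.items():
        matches = re.findall(pattern, text)
        if matches:
            found[name] = matches
    return found


print(preview_regex_matches("Dr. Vivek Murthy, who led the clinical trial NCT12345678?"))
```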

In [22]:
# Define the query that will be sent to the model
query = ("Can you provide the contact information, including the phone number and email address, for Dr. Vivek Murthy, "
         "who led the clinical trial NCT12345678?")

# Retrieve relevant context from OpenSearch based on the query
context_prompt = get_opensearch_vector_context(query)

input_prompt = PROMPT.format(question=query, context=context_prompt)

# Invoke the model using the prompt
raw_response = sagemaker_llm.invoke(input_prompt)
print("\n Initial response without guardrails: ", raw_response)

# Apply the output guardrail to the raw model response
guardrail_response = apply_output_guardrail(raw_response)

# Print the user's query
print("\n User's query: ", query)

# Print the model's answer after the guardrail has been applied
print("\n Answer with guardrails: ", guardrail_response)

Retrieved context:  ['ID: 23949294 - Of 619 patients on follow-up, 82 (13.2%) were diagnosed in the pre-HAART era. At the time of our study, 79 (96.3%) patients were on HAART, with a median duration of 14 years (IQR 12-15) of therapy, and exposure to mono or dual nucleoside reverse transcriptase inhibitors regimens in 47.8% of cases.\xa0Sixty-nine patients (87.3%) had undetectable VL, 37 (46.8%) never presented virologic failure, and 19 (24.1%) experienced only one failure. Thirteen patients (16.5%) were receiving third-line ART regimens, with an average of 2.7-fold more virologic failures than those on first- or second-line regimens (p = 0.007).', "ID: 21252642 - It was hypothesized that patients' satisfaction with information regarding clinical trials would improve after targeted educational interventions, and accruals to clinical trials would increase in the year following those interventions.", 'ID: 21252642 - Patient satisfaction with information significantly increased after the 