# Bedrock Knowledge Base Retrieval and Generation with SageMaker Inference and Guardrails

## Description
This notebook demonstrates how to enhance a Retrieval-Augmented Generation (RAG) pipeline by combining Amazon Bedrock Knowledge Bases with Amazon SageMaker Inference. It walks through querying a knowledge base with a metadata filter, generating a response from a model hosted on a SageMaker endpoint, and applying Guardrails to the generated response to support compliance and quality.


![Guardrails](./guardrail.png)

## 1. Import Required Functions

In [None]:
# boto3 is needed later to create the Bedrock runtime client
import boto3

# Import necessary functions from advanced_rag_utils
from advanced_rag_utils import (
    load_variables,
    create_standard_filter,
    setup_bedrock_client,
    retrieve_from_bedrock_with_filter,
    format_llama3_prompt,
    generate_sagemaker_response,
    apply_output_guardrail,
    retrieve_generate_apply_guardrails
)

## 2. Load Configuration Variables

In [None]:
# Load configuration variables from a JSON file
variables = load_variables("../variables.json")
variables  # Display the loaded variables for confirmation

## 3. Set Up Required IDs and Configuration

In [None]:
# Knowledge Base Selection  
kb_id = variables["kbFixedChunk"]  # Options: "kbFixedChunk", "kbHierarchicalChunk", "kbSemanticChunk"

# Retrieve guardrail details
guardrail_id = variables["guardrail_id"]
guardrail_version = variables["guardrail_version"]

# SageMaker endpoint
sagemaker_endpoint = variables['sagemakerLLMEndpoint']

# Retrieval-Augmented Generation (RAG) Configuration  
number_of_results = 3  # Number of relevant documents to retrieve  
generation_configuration = {
    "temperature": 0,  # Lower temperature for more deterministic responses  
    "top_k": 10,  # Consider top 10 tokens at each generation step  
    "max_new_tokens": 5000,  # Maximum number of tokens to generate  
    "stop": "<|eot_id|>"  # Stop sequence to end the response generation  
}
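
The generation settings above are eventually sent to the SageMaker endpoint together with the prompt. The sketch below shows one way such a request body could be assembled. It assumes the endpoint accepts the common `{"inputs": ..., "parameters": ...}` JSON schema, which depends on the serving container. `build_sagemaker_payload` is a hypothetical name and is not part of `advanced_rag_utils`.

```python
import json


def build_sagemaker_payload(prompt, generation_config):
    """Serialize a prompt and generation settings into a JSON request body.

    Assumption: the endpoint expects {"inputs": ..., "parameters": ...}.
    Check the schema of your own serving container before relying on this.
    """
    payload = {"inputs": prompt, "parameters": dict(generation_config)}
    return json.dumps(payload).encode("utf-8")


example_config = {
    "temperature": 0,
    "top_k": 10,
    "max_new_tokens": 5000,
    "stop": "<|eot_id|>",
}
body = build_sagemaker_payload("What was the net revenue?", example_config)

# The body could then be sent with the SageMaker runtime client, for example:
#   boto3.client("sagemaker-runtime").invoke_endpoint(
#       EndpointName=sagemaker_endpoint,
#       ContentType="application/json",
#       Body=body,
#   )
```

Copying the configuration with `dict(...)` keeps the caller's dictionary unchanged if the payload is modified later.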

## 4. Define Metadata Filter

In [None]:
# Create a standard filter for document type and year
metadata_filter = create_standard_filter('10K Report', 2023)
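
The implementation of `create_standard_filter` is not shown in this notebook. As a rough illustration, a Bedrock Knowledge Base metadata filter that matches on two attributes could look like the sketch below. The metadata keys `docType` and `year` and both function names are assumptions made for this example; the real keys depend on how the documents were ingested.

```python
def build_metadata_filter(doc_type, year):
    """Build a filter that requires both attributes to match.

    Assumption: documents carry "docType" and "year" metadata keys.
    """
    return {
        "andAll": [
            {"equals": {"key": "docType", "value": doc_type}},
            {"equals": {"key": "year", "value": year}},
        ]
    }


def build_retrieval_configuration(metadata_filter, num_results):
    """Wrap the filter in the structure expected by the Retrieve API."""
    return {
        "vectorSearchConfiguration": {
            "numberOfResults": num_results,
            "filter": metadata_filter,
        }
    }


example_filter = build_metadata_filter("10K Report", 2023)
example_retrieval_config = build_retrieval_configuration(example_filter, 3)

# With a "bedrock-agent-runtime" client, the configuration could be passed as:
#   client.retrieve(
#       knowledgeBaseId=kb_id,
#       retrievalQuery={"text": query},
#       retrievalConfiguration=example_retrieval_config,
#   )
```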

## 5. Initialize Bedrock Clients

In [None]:
# Initialize the Bedrock clients
bedrock_agent_client = setup_bedrock_client(variables["regionName"])
bedrock_runtime_client = boto3.client("bedrock-runtime", region_name=variables["regionName"])

## 6. Test Guardrails for Investment Advice
Let's ask the model for investment advice. When we created the guardrail, we restricted Bedrock from providing any investment advice, so Bedrock should return the preconfigured response "This request cannot be processed due to safety protocols".

In [None]:
# Define the query for testing investment advice restriction
query = "Based on Amazon's results, should I buy Amazon stock?"

In [None]:
# Use the comprehensive function to perform the RAG pipeline with guardrails
guardrail_response, raw_response, context = retrieve_generate_apply_guardrails(
    query=query,
    knowledge_base_id=kb_id,
    sagemaker_endpoint=sagemaker_endpoint,
    guardrail_id=guardrail_id,
    guardrail_version=guardrail_version,
    metadata_filter=metadata_filter,
    generation_config=generation_configuration,
    bedrock_agent_client=bedrock_agent_client,
    bedrock_runtime_client=bedrock_runtime_client,
    num_results=number_of_results,
    region_name=variables["regionName"]
)

# Print the query and response
print("Question:", query)
# print(f"Context: {context}")  # Uncomment for debugging
print("\nRaw Response (Without Guardrails):")
print(raw_response)
print("\nGuardrail Response:")
print(guardrail_response)
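
The helper functions hide how the guardrail result is interpreted. The sketch below shows one way a response from the Bedrock `ApplyGuardrail` API could be read: when the guardrail intervenes, the replacement text is taken from `outputs`, otherwise the original text passes through. `extract_guardrail_text` is a hypothetical name, and this is not necessarily how `apply_output_guardrail` is implemented.

```python
def extract_guardrail_text(response, original_text):
    """Return the guardrail's replacement text, or the original if untouched."""
    if response.get("action") == "GUARDRAIL_INTERVENED":
        outputs = response.get("outputs", [])
        if outputs:
            return outputs[0]["text"]
    return original_text


# Hand-written sample responses that mimic the API shape (not real API output).
blocked_response = {
    "action": "GUARDRAIL_INTERVENED",
    "outputs": [
        {"text": "This request cannot be processed due to safety protocols"}
    ],
}
allowed_response = {"action": "NONE", "outputs": []}

blocked_text = extract_guardrail_text(blocked_response, "You should buy the stock.")
allowed_text = extract_guardrail_text(allowed_response, "Net sales increased.")

# A real response could be obtained with a "bedrock-runtime" client:
#   client.apply_guardrail(
#       guardrailIdentifier=guardrail_id,
#       guardrailVersion=guardrail_version,
#       source="OUTPUT",
#       content=[{"text": {"text": raw_response}}],
#   )
```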

## 7. Test Guardrails for PII Data
Next, let's ask a question whose answer contains a person's name. The guardrail is expected to anonymize personally identifiable information (PII) in the response.

In [None]:
# Define a query for testing PII anonymization
query = "Who is the current CFO of Amazon?"

In [None]:
# Use the comprehensive function to perform the RAG pipeline with guardrails
guardrail_response, raw_response, context = retrieve_generate_apply_guardrails(
    query=query,
    knowledge_base_id=kb_id,
    sagemaker_endpoint=sagemaker_endpoint,
    guardrail_id=guardrail_id,
    guardrail_version=guardrail_version,
    metadata_filter=metadata_filter,
    generation_config=generation_configuration,
    bedrock_agent_client=bedrock_agent_client,
    bedrock_runtime_client=bedrock_runtime_client,
    num_results=number_of_results,
    region_name=variables["regionName"]
)

# Print the query and response
print("Question:", query)
# print(f"Context: {context}")  # Uncomment for debugging
print("\nRaw Response (Without Guardrails):")
print(raw_response)
print("\nGuardrail Response (With PII Anonymization):")
print(guardrail_response)

## 8. (Optional) Step-by-Step Approach

In [None]:
# If you prefer to execute the steps individually:

# 1. Retrieve context from Bedrock KB with metadata filtering
context = retrieve_from_bedrock_with_filter(
    query=query,
    knowledge_base_id=kb_id,
    metadata_filter=metadata_filter,
    bedrock_client=bedrock_agent_client,
    num_results=number_of_results,
    region_name=variables["regionName"]
)

# 2. Format prompt using retrieved context
prompt = format_llama3_prompt(query, context)

# 3. Generate response using SageMaker endpoint
raw_response = generate_sagemaker_response(
    prompt=prompt,
    endpoint_name=sagemaker_endpoint,
    generation_config=generation_configuration
)

# 4. Apply guardrails to the output
guardrail_response = apply_output_guardrail(
    output_text=raw_response,
    guardrail_id=guardrail_id,
    guardrail_version=guardrail_version,
    bedrock_client=bedrock_runtime_client,
    region_name=variables["regionName"]
)

# 5. Display results
print("Question:", query)
# print(f"Context: {context}")  # Uncomment for debugging
print("\nRaw Response (Without Guardrails):")
print(raw_response)
print("\nGuardrail Response:")
print(guardrail_response)
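
Step 2 above relies on `format_llama3_prompt`, whose body is not shown in this notebook. For orientation, the sketch below builds a prompt in the Llama 3 instruct chat format using its special tokens. The function name `sketch_llama3_prompt`, the system message, and the way context and question are laid out are assumptions made for illustration; the real helper may differ.

```python
def sketch_llama3_prompt(question, context, system_prompt=None):
    """Build a single-turn Llama 3 instruct prompt from a question and context.

    Assumption: the system message and the "Context / Question" layout are
    illustrative choices, not taken from advanced_rag_utils.
    """
    if system_prompt is None:
        system_prompt = (
            "Answer the question using only the provided context. "
            "If the context is insufficient, say you do not know."
        )
    return (
        "<|begin_of_text|>"
        "<|start_header_id|>system<|end_header_id|>\n\n"
        f"{system_prompt}<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"Context:\n{context}\n\nQuestion: {question}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )


example_prompt = sketch_llama3_prompt(
    question="Who is the current CFO of Amazon?",
    context="(retrieved passages would go here)",
)
```

The prompt ends with an open assistant header so the model continues with its answer, and the `<|eot_id|>` stop sequence in the generation configuration ends generation once that turn is complete.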