## Contextual retrieval using custom chunking provided by Amazon Bedrock Knowledge Bases

In this notebook, we will create 2 knowledge bases to provide sample code for the following chunking options supported by Amazon Bedrock Knowledge Bases: 
- Fixed Chunking
- Contextual retrieval using custom chunking with Lambda function 

We will use a synthetic 10K report for a fictitious company called `Octank Financial` to demonstrate the solution.
After creating the knowledge bases, we will evaluate their results on the same dataset. The focus is on improving the quality of search results, which in turn improves the accuracy of the responses generated by the foundation model. 

## 1. Import the needed libraries
The first step is to install the prerequisite packages.

In [None]:
%pip install --upgrade pip --quiet
%pip install -r ../requirements.txt --no-deps --quiet
%pip install -r ../requirements.txt --upgrade --quiet
%pip install --upgrade ragas --quiet

In [None]:
# restart kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")

In [None]:
import botocore
botocore.__version__

In [None]:
import os
import sys
import time
import boto3
import logging
import pprint
import json

# Set the path to import module
from pathlib import Path
current_path = Path().resolve()
current_path = current_path.parent
if str(current_path) not in sys.path:
    sys.path.append(str(current_path))
# Print sys.path to verify
# print(sys.path)

from utils.knowledge_base import BedrockKnowledgeBase

In [None]:
#Clients
s3_client = boto3.client('s3')
sts_client = boto3.client('sts')
session = boto3.session.Session()
region =  session.region_name
account_id = sts_client.get_caller_identity()["Account"]
bedrock_agent_client = boto3.client('bedrock-agent')
bedrock_agent_runtime_client = boto3.client('bedrock-agent-runtime') 
logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
region, account_id

In [None]:
import time

# Get the current timestamp
current_time = time.time()

# Format the timestamp as a string
timestamp_str = time.strftime("%Y%m%d%H%M%S", time.localtime(current_time))[-7:]
# Create the suffix using the timestamp
suffix = f"{timestamp_str}"
knowledge_base_name_standard = 'standard-kb'
knowledge_base_name_custom = 'custom-chunking-kb'
knowledge_base_description = "Knowledge Base containing complex PDF."
bucket_name = f'{knowledge_base_name_standard}-{suffix}'
intermediate_bucket_name = f'{knowledge_base_name_standard}-intermediate-{suffix}'
lambda_function_name = f'{knowledge_base_name_custom}-lambda-{suffix}'
foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0"

# Define data sources
data_source=[{"type": "S3", "bucket_name": bucket_name}]

## 2 - Create knowledge bases with fixed chunking strategy
Let's start by creating a knowledge base with [Amazon Bedrock Knowledge Bases](https://aws.amazon.com/bedrock/knowledge-bases/) to store the Octank Financial 10K report. Knowledge Bases allow you to integrate with different vector databases including [Amazon OpenSearch Serverless](https://aws.amazon.com/opensearch-service/features/serverless/), [Amazon Aurora](https://aws.amazon.com/rds/aurora/), [Pinecone](http://app.pinecone.io/bedrock-integration), Redis Enterprise and MongoDB Atlas. For this example, we will integrate the knowledge base with Amazon OpenSearch Serverless. To do so, we will use the helper class `BedrockKnowledgeBase`, which will create the knowledge base and all of its prerequisites:
1. IAM roles and policies
2. S3 bucket
3. Amazon OpenSearch Serverless encryption, network and data access policies
4. Amazon OpenSearch Serverless collection
5. Amazon OpenSearch Serverless vector index
6. Knowledge base
7. Knowledge base data source

First we will create a knowledge base using the fixed chunking strategy; a second one using custom chunking follows later in the notebook. 

Parameter values: 
```
"chunkingStrategy": "FIXED_SIZE | NONE"
```
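
For reference, a fixed-size strategy corresponds to a `chunkingConfiguration` block inside the data source's `vectorIngestionConfiguration`. The sketch below builds that block as a plain dictionary. The helper name `fixed_size_chunking_config` and the values 300 and 20 are illustrative assumptions, not settings read from the `BedrockKnowledgeBase` helper.

```python
def fixed_size_chunking_config(max_tokens=300, overlap_percentage=20):
    """Build a vectorIngestionConfiguration dict for fixed-size chunking.

    max_tokens and overlap_percentage are illustrative defaults (assumptions).
    """
    return {
        "chunkingConfiguration": {
            "chunkingStrategy": "FIXED_SIZE",
            "fixedSizeChunkingConfiguration": {
                "maxTokens": max_tokens,                  # upper bound on tokens per chunk
                "overlapPercentage": overlap_percentage,  # overlap between consecutive chunks
            },
        }
    }


config = fixed_size_chunking_config()
print(config["chunkingConfiguration"]["chunkingStrategy"])
```

A dictionary of this shape is what a data source creation call would receive; in this notebook the helper class builds it for you.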

In [None]:
knowledge_base_standard = BedrockKnowledgeBase(
    kb_name=f'{knowledge_base_name_standard}-{suffix}',
    kb_description=knowledge_base_description,
    data_sources=data_source,
    chunking_strategy = "FIXED_SIZE", 
    suffix = f'{suffix}-f'
)

### 2.1 Upload the dataset to Amazon S3
Now that we have created the knowledge base, let's populate it with the `Octank financial 10K` report dataset. The Knowledge Base data source expects the data to be available in the S3 bucket connected to it, and changes to the data can be synchronized to the knowledge base using the `StartIngestionJob` API call. In this example we will use the [boto3 abstraction](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent/client/start_ingestion_job.html) of the API, via our helper class. 
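
An ingestion job runs asynchronously, so a caller usually polls its status until it finishes. The sketch below shows one way to do that with the `get_ingestion_job` call of the `bedrock-agent` client. The function name `wait_for_ingestion`, the polling interval and the `FakeAgentClient` stand-in (used so the example runs offline) are assumptions for illustration; the helper class in this notebook may wait differently.

```python
import time

# Statuses after which the job will not change again.
TERMINAL_STATES = {"COMPLETE", "FAILED", "STOPPED"}


def wait_for_ingestion(client, kb_id, data_source_id, job_id,
                       poll_seconds=10, max_polls=60):
    """Poll get_ingestion_job until the job reaches a terminal status."""
    for _ in range(max_polls):
        job = client.get_ingestion_job(
            knowledgeBaseId=kb_id,
            dataSourceId=data_source_id,
            ingestionJobId=job_id,
        )["ingestionJob"]
        if job["status"] in TERMINAL_STATES:
            return job["status"]
        time.sleep(poll_seconds)
    raise TimeoutError(f"Ingestion job {job_id} did not finish in time")


class FakeAgentClient:
    """Hypothetical stand-in for boto3.client('bedrock-agent')."""

    def __init__(self, statuses):
        self._statuses = iter(statuses)

    def get_ingestion_job(self, **kwargs):
        return {"ingestionJob": {"status": next(self._statuses)}}


fake = FakeAgentClient(["STARTING", "IN_PROGRESS", "COMPLETE"])
final_status = wait_for_ingestion(fake, "kb-id", "ds-id", "job-id", poll_seconds=0)
print(final_status)
```

With a real client you would pass `boto3.client('bedrock-agent')` and the identifiers returned when the job was started.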

Let's first upload the report data available in the `synthetic_dataset` folder to S3.

In [None]:
import os

def upload_directory(path, bucket_name):
    for root, dirs, files in os.walk(path):
        for file in files:
            file_to_upload = os.path.join(root, file)
            if file not in ["LICENSE", "NOTICE", "README.md"]:
                print(f"uploading file {file_to_upload} to {bucket_name}")
                s3_client.upload_file(file_to_upload, bucket_name, file)
            else:
                print(f"Skipping file {file_to_upload}")

upload_directory("../synthetic_dataset", bucket_name)


Now we start the ingestion job.

In [None]:
# ensure that the kb is available
time.sleep(30)
# sync knowledge base
knowledge_base_standard.start_ingestion_job()

Finally we save the Knowledge Base Id to test the solution at a later stage. 

In [None]:
kb_id_standard = knowledge_base_standard.get_knowledge_base_id()

### 2.2 Test the Knowledge Base
Now that the Knowledge Base is available, we can test it using the [**retrieve**](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve.html) and [**retrieve_and_generate**](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html) functions. 

#### Testing Knowledge Base with Retrieve and Generate API

Let's first test the knowledge base using the retrieve and generate API. With this API, Bedrock takes care of retrieving the necessary references from the knowledge base and generating the final answer using a foundation model from Bedrock.

query = `Provide a summary of consolidated statements of cash flows of Octank Financial for the fiscal years ended December 31, 2019.`

The right response for this query as per ground truth QA pair is: 

```
The cash flow statement for Octank Financial in the year ended December 31, 2019 reveals the following:
- Cash generated from operating activities amounted to $710 million, which can be attributed to a $700 million profit and non-cash charges such as depreciation and amortization.
- Cash outflow from investing activities totaled $240 million, with major expenditures being the acquisition of property, plant, and equipment ($200 million) and marketable securities ($60 million), partially offset by the sale of property, plant, and equipment ($40 million) and maturing marketable securities ($20 million).
- Financing activities resulted in a cash inflow of $350 million, stemming from the issuance of common stock ($200 million) and long-term debt ($300 million), while common stock repurchases ($50 million) and long-term debt payments ($100 million) reduced the cash inflow. 
Overall, Octank Financial experienced a net cash enhancement of $120 million in 2019, bringing their total cash and cash equivalents to $210 million.
```

In [None]:
query = "Provide a summary of consolidated statements of cash flows of Octank Financial for the fiscal years ended December 31, 2019."

In [None]:
time.sleep(20)
response = bedrock_agent_runtime_client.retrieve_and_generate(
    input={
        "text": query
    },
    retrieveAndGenerateConfiguration={
        "type": "KNOWLEDGE_BASE",
        "knowledgeBaseConfiguration": {
            'knowledgeBaseId': kb_id_standard,
            "modelArn": "arn:aws:bedrock:{}::foundation-model/{}".format(region, foundation_model),
            "retrievalConfiguration": {
                "vectorSearchConfiguration": {
                    "numberOfResults":5
                } 
            }
        }
    }
)

print(response['output']['text'],end='\n'*2)

As you can see, the `RetrieveAndGenerate` API returns the final response directly. Let's now look at the citations it returns, since the primary focus of this notebook is to observe the retrieved chunks and citations the model used while generating the response. When we provide relevant context to the foundation model along with the query, it is more likely to generate a high-quality response. 
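
The `citations` field is a list, and each entry carries its own `retrievedReferences`. The sketch below gathers the unique reference texts across all citations rather than only the first one. The function name `collect_references` and the sample response are hypothetical; the nesting follows the structure used by the cells in this notebook.

```python
def collect_references(response):
    """Return unique retrieved reference texts across all citations, in order."""
    seen = set()
    texts = []
    for citation in response.get("citations", []):
        for ref in citation.get("retrievedReferences", []):
            text = ref["content"]["text"]
            if text not in seen:  # the same chunk can back several citations
                seen.add(text)
                texts.append(text)
    return texts


# Hypothetical response fragment with the same nesting as RetrieveAndGenerate output.
sample_response = {
    "citations": [
        {"retrievedReferences": [{"content": {"text": "chunk A"}},
                                 {"content": {"text": "chunk B"}}]},
        {"retrievedReferences": [{"content": {"text": "chunk A"}}]},
    ]
}
unique_texts = collect_references(sample_response)
print(len(unique_texts))
```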

In [None]:
def citations_rag_print(response_ret):
#structure 'retrievalResults': list of contents. Each list has content, location, score, metadata
    for num,chunk in enumerate(response_ret,1):
        print(f'Chunk {num}: ',chunk['content']['text'],end='\n'*2)
        print(f'Chunk {num} Location: ',chunk['location'],end='\n'*2)
        print(f'Chunk {num} Metadata: ',chunk['metadata'],end='\n'*2)

In [None]:
response_standard = response['citations'][0]['retrievedReferences']
print("# of citations or chunks used to generate the response: ", len(response_standard))
citations_rag_print(response_standard)

Let's now inspect the source information from the knowledge base with the retrieve API.

#### Testing Knowledge Base with Retrieve API
If you need an extra layer of control, you can retrieve the chunks that best match your query using the retrieve API. In this setup, we can configure the desired number of results and control the final answer with your own application logic. The API then provides you with the matching content, its S3 location, the similarity score and the chunk metadata.
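
As one example of such application logic, the sketch below drops results whose similarity score falls under a threshold before they are passed on. The function name `filter_by_score`, the threshold of 0.5 and the sample data are assumptions for illustration; a suitable cutoff depends on your embedding model and data.

```python
def filter_by_score(retrieve_response, min_score=0.5):
    """Keep only retrieval results whose score is at least min_score."""
    return [
        result
        for result in retrieve_response.get("retrievalResults", [])
        if result.get("score", 0.0) >= min_score
    ]


# Hypothetical response fragment shaped like the Retrieve API output.
sample_retrieve_response = {
    "retrievalResults": [
        {"content": {"text": "cash flow summary"}, "score": 0.82},
        {"content": {"text": "unrelated footnote"}, "score": 0.31},
    ]
}
kept = filter_by_score(sample_retrieve_response)
print([r["content"]["text"] for r in kept])
```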

In [None]:
def response_print(response_ret):
#structure 'retrievalResults': list of contents. Each list has content, location, score, metadata
    for num,chunk in enumerate(response_ret['retrievalResults'],1):
        print(f'Chunk {num}: ',chunk['content']['text'],end='\n'*2)
        print(f'Chunk {num} Location: ',chunk['location'],end='\n'*2)
        print(f'Chunk {num} Score: ',chunk['score'],end='\n'*2)
        print(f'Chunk {num} Metadata: ',chunk['metadata'],end='\n'*2)


In [None]:
response_standard_ret = bedrock_agent_runtime_client.retrieve(
    knowledgeBaseId=kb_id_standard, 
    retrievalConfiguration={
        "vectorSearchConfiguration": {
            "numberOfResults":5,
        } 
    },
    retrievalQuery={
        'text': query
    }
)

print("# of retrieved results: ", len(response_standard_ret['retrievalResults']))
response_print(response_standard_ret)

As you can see, with `fixed chunking` we get the 5 retrieved results requested in the API call, found using `semantic similarity`, which is the default for the `Retrieve` API. Let's now use the `Custom chunking` strategy for contextual retrieval and inspect the retrieved results using the `RetrieveAndGenerate` API as well as the `Retrieve` API. 
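
If you prefer to pin the search type explicitly rather than rely on the default, the `vectorSearchConfiguration` accepts an `overrideSearchType` field. The sketch below builds such a configuration. The helper name `retrieval_config` is hypothetical, and whether `HYBRID` is available depends on the vector store behind your knowledge base, so treat that option as an assumption to verify.

```python
def retrieval_config(number_of_results=5, search_type=None):
    """Build a retrievalConfiguration dict, optionally pinning the search type."""
    vector_cfg = {"numberOfResults": number_of_results}
    if search_type is not None:
        if search_type not in ("HYBRID", "SEMANTIC"):
            raise ValueError("search_type must be 'HYBRID' or 'SEMANTIC'")
        vector_cfg["overrideSearchType"] = search_type
    return {"vectorSearchConfiguration": vector_cfg}


semantic_cfg = retrieval_config(search_type="SEMANTIC")
print(semantic_cfg)
```

The resulting dictionary can be passed as `retrievalConfiguration` to the `retrieve` call used in this notebook.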

## 2. Contextual retrieval using custom chunking option
When creating a knowledge base (KB) with Amazon Bedrock Knowledge Bases, you can connect a Lambda function to specify your custom chunking logic. During ingestion, if a Lambda function is provided, Knowledge Bases will run it and store the input and output values in the intermediate S3 bucket provided.

> Note: A Lambda function can be used with a KB to add custom chunking logic as well as to process your chunks, for example by adding chunk-level metadata. In this example we focus on using a Lambda function for custom chunking logic.
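
To make the data flow concrete, the sketch below shows the shape of the event and of one intermediate batch file as the Lambda handler in this notebook reads them. All values (bucket name, keys, text) are hypothetical placeholders, and `originalFileLocation` is treated as an opaque object that is passed through unchanged.

```python
# Hypothetical event, using only the fields the handler in this notebook reads.
sample_event = {
    "bucketName": "example-intermediate-bucket",
    "inputFiles": [
        {
            "originalFileLocation": {"note": "opaque location object"},
            "contentBatches": [{"key": "batches/report_1.json"}],
        }
    ],
}

# Hypothetical content of one batch file stored in the intermediate bucket.
sample_batch_file = {
    "fileContents": [
        {
            "contentBody": "Octank Financial reported ...",
            "contentType": "text",
            "contentMetadata": {},
        }
    ]
}


def list_batch_keys(event):
    """Return the S3 keys of every content batch referenced by the event."""
    return [
        batch["key"]
        for input_file in event.get("inputFiles", [])
        for batch in input_file.get("contentBatches", [])
    ]


batch_keys = list_batch_keys(sample_event)
print(batch_keys)
```

The handler reads each key from the intermediate bucket, rewrites `fileContents`, and returns the new keys under `outputFiles`.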

### 2.1 Create the Lambda Function

We will now create a lambda function which will have code for custom chunking. To do so we will:

1. Create the `lambda_function.py` file which contains the logic for custom chunking.
2. Create the IAM role for our Lambda function.
3. Create the lambda function with the required permissions.

#### Create the function code
Let's create the Lambda function that implements three steps: `read your file from the intermediate bucket`, `process the contents with custom chunking logic` and `write the output back to the S3 bucket`. 

In [None]:
%%writefile lambda_function.py
import json
import boto3
import os
import logging
import traceback
from botocore.exceptions import ClientError

# Set up logging
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Constants for chunking
MAX_TOKENS = 1000
OVERLAP_PERCENTAGE = 0.20

def estimate_tokens(text):
    """
    Rough estimation of tokens (approximation: 1 token ≈ 0.75 words)
    """
    words = text.split()
    # 1 token ≈ 0.75 words, so tokens ≈ words / 0.75
    return int(len(words) / 0.75)

def chunk_text(text, max_tokens=MAX_TOKENS, overlap_percentage=OVERLAP_PERCENTAGE):
    """
    Chunk text based on words with specified max tokens and overlap
    """
    if not text:
        return []

    # Split text into words
    words = text.split()
    if not words:
        return []

    # Estimate words per chunk based on token limit
    # Using the same approximation as estimate_tokens: 1 token ≈ 0.75 words
    words_per_chunk = int(max_tokens * 0.75)  # Convert tokens to approximate words
    overlap_words = int(words_per_chunk * overlap_percentage)

    chunks = []
    current_position = 0
    total_words = len(words)

    while current_position < total_words:
        # Calculate end position for current chunk
        chunk_end = min(current_position + words_per_chunk, total_words)
        
        # Get current chunk words
        chunk_words = words[current_position:chunk_end]
        
        # If this isn't the last chunk, try to find a good break point
        if chunk_end < total_words:
            # Look for sentence-ending punctuation in the last few words
            for i in range(len(chunk_words) - 1, max(len(chunk_words) - 10, 0), -1):
                if chunk_words[i].endswith(('.', '!', '?')):
                    chunk_end = current_position + i + 1
                    chunk_words = chunk_words[:i + 1]
                    break

        # Join words back into text (named chunk_str to avoid shadowing this function's name)
        chunk_str = ' '.join(chunk_words)
        chunks.append(chunk_str.strip())
        
        # Move position considering overlap
        current_position = chunk_end - overlap_words if chunk_end < total_words else chunk_end

    return chunks

def write_output_to_s3(s3_client, bucket_name, file_name, json_data):
    """
    Write JSON data to S3 bucket
    """
    try:
        json_string = json.dumps(json_data)
        response = s3_client.put_object(
            Bucket=bucket_name,
            Key=file_name,
            Body=json_string,
            ContentType='application/json'
        )

        if response['ResponseMetadata']['HTTPStatusCode'] == 200:
            print(f"Successfully uploaded {file_name} to {bucket_name}")
            return True
        else:
            print(f"Failed to upload {file_name} to {bucket_name}")
            return False

    except ClientError as e:
        print(f"Error occurred: {e}")
        return False

def read_from_s3(s3_client, bucket_name, file_name):
    """
    Read JSON data from S3 bucket
    """
    try:
        response = s3_client.get_object(Bucket=bucket_name, Key=file_name)
        return json.loads(response['Body'].read().decode('utf-8'))
    except ClientError as e:
        print(f"Error reading file from S3: {str(e)}")
        # Re-raise so the handler fails loudly instead of continuing with None
        raise

def parse_s3_path(s3_path):
    """
    Parse S3 path into bucket and key
    """
    s3_path = s3_path.replace('s3://', '')
    parts = s3_path.split('/', 1)
    if len(parts) != 2:
        raise ValueError("Invalid S3 path format")
    return parts[0], parts[1]

def invoke_model_with_response_stream(bedrock_runtime, prompt, max_tokens=1000):
    """
    Invoke Bedrock model with streaming response
    """
    model_id = 'anthropic.claude-3-haiku-20240307-v1:0'
    request_body = json.dumps({
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "temperature": 0.0,
    })

    try:
        response = bedrock_runtime.invoke_model_with_response_stream(
            modelId=model_id,
            contentType='application/json',
            accept='application/json',
            body=request_body
        )

        for event in response.get('body'):
            chunk = json.loads(event['chunk']['bytes'].decode())
            if chunk['type'] == 'content_block_delta':
                yield chunk['delta']['text']
            elif chunk['type'] == 'message_delta':
                if 'stop_reason' in chunk['delta']:
                    break

    except ClientError as e:
        print(f"An error occurred: {e}")
        yield None

# Define the contextual retrieval prompt
contextual_retrieval_prompt = """
    <document>
    {doc_content}
    </document>

    Here is the chunk we want to situate within the whole document
    <chunk>
    {chunk_content}
    </chunk>

    Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
    Answer only with the succinct context and nothing else.
    """

def lambda_handler(event, context):
    """
    Lambda handler function
    """
    logger.debug('input={}'.format(json.dumps(event)))

    s3_client = boto3.client('s3')
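    bedrock_runtime = boto3.client(
        service_name='bedrock-runtime',
        # Use the function's own region (AWS_REGION is set by the Lambda runtime); fall back to us-east-1
        region_name=os.environ.get('AWS_REGION', 'us-east-1')
    )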

    input_files = event.get('inputFiles')
    input_bucket = event.get('bucketName')

    if not all([input_files, input_bucket]):
        raise ValueError("Missing required input parameters")

    output_files = []
    for input_file in input_files:
        processed_batches = []
        for batch in input_file.get('contentBatches'):
            input_key = batch.get('key')

            if not input_key:
                raise ValueError("Missing key in content batch")

            file_content = read_from_s3(s3_client, bucket_name=input_bucket, file_name=input_key)
            print(file_content.get('fileContents'))

            original_document_content = ''.join(
                content.get('contentBody') 
                for content in file_content.get('fileContents') 
                if content
            )

            chunked_content = {
                'fileContents': []
            }
            
            for content in file_content.get('fileContents'):
                content_body = content.get('contentBody', '')
                content_type = content.get('contentType', '')
                content_metadata = content.get('contentMetadata', {})

                # Apply chunking strategy
                chunks = chunk_text(content_body)
                
                for chunk in chunks:
                    prompt = contextual_retrieval_prompt.format(
                        doc_content=original_document_content, 
                        chunk_content=chunk
                    )
                    response_stream = invoke_model_with_response_stream(bedrock_runtime, prompt)
                    chunk_context = ''.join(chunk_text for chunk_text in response_stream if chunk_text)

                    chunked_content['fileContents'].append({
                        "contentBody": chunk_context + "\n\n" + chunk,
                        "contentType": content_type,
                        "contentMetadata": content_metadata,
                    })

            output_key = f"Output/{input_key}"
            write_output_to_s3(s3_client, input_bucket, output_key, chunked_content)
            processed_batches.append({"key": output_key})
            
        output_files.append({
            "originalFileLocation": input_file.get('originalFileLocation'),
            "fileMetadata": {},
            "contentBatches": processed_batches
        })

    return {
        "outputFiles": output_files
    }

Knowledge Bases provides the following standard chunking strategy values: 

**Parameter values:**
 
```
"chunkingStrategy": "FIXED_SIZE | NONE | HIERARCHICAL | SEMANTIC"
```

To implement our custom logic, we have included an option in the `knowledge_base.py` class for passing a value of `CUSTOM`. 
If you pass the chunking strategy as `CUSTOM` in this class, it will do the following: 

1. It selects the `chunkingStrategy` as `NONE`. 
2. It adds `customTransformationConfiguration` to the `vectorIngestionConfiguration` as follows: 

```
{
...
   "vectorIngestionConfiguration": {
      "customTransformationConfiguration": {
         "intermediateStorage": {
            "s3Location": {
               "uri": "string"
            }
         },
         "transformations": [
            {
               "transformationFunction": {
                  "transformationLambdaConfiguration": {
                     "lambdaArn": "string"
                  }
               },
               "stepToApply": "string" // enum of POST_CHUNKING
            }
         ]
      },
      "chunkingConfiguration": {
         "chunkingStrategy": "NONE"
         ...
      }
   }
}
```


In [None]:
knowledge_base_custom = BedrockKnowledgeBase(
    kb_name=f'{knowledge_base_name_custom}-{suffix}',
    kb_description=knowledge_base_description,
    data_sources=data_source,
    lambda_function_name=lambda_function_name,
    intermediate_bucket_name=intermediate_bucket_name, 
    chunking_strategy = "CUSTOM", 
    suffix = f'{suffix}-c'
)

In [None]:
### Update the AWS Lambda IAM role with additional permissions to access the Bedrock invoke API, and increase the Lambda function timeout

In [None]:
import boto3
import json

# Initialize AWS clients
lambda_client = boto3.client('lambda')
iam_client = boto3.client('iam')

def update_lambda_timeout():
    try:
        # Update Lambda function configuration
        response = lambda_client.update_function_configuration(
            FunctionName=lambda_function_name,
            Timeout=900  # 15 minutes in seconds
        )
        print(f"Successfully updated Lambda timeout to 15 minutes")
    except Exception as e:
        print(f"Error updating Lambda timeout: {str(e)}")

def get_lambda_role():
    # Get Lambda function configuration
    response = lambda_client.get_function_configuration(
        FunctionName=lambda_function_name
    )
    
    # Extract the role ARN
    role_arn = response['Role']
    role_name = role_arn.split('/')[-1]
    
    print(f"Lambda Role ARN: {role_arn}")
    print(f"Lambda Role Name: {role_name}")
    
    return role_name

def create_bedrock_policy(role_name):
    # Define the policy document for Bedrock access
    bedrock_policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": [
                    "bedrock:InvokeModelWithResponseStream"
                ],
                "Resource": [
                    "*"
                ]
            }
        ]
    }
    
    # Create the policy
    try:
        response = iam_client.create_policy(
            PolicyName='BedrockClaudeAccess',
            PolicyDocument=json.dumps(bedrock_policy)
        )
        policy_arn = response['Policy']['Arn']
    except iam_client.exceptions.EntityAlreadyExistsException:
        # If policy already exists, get its ARN
        account_id = boto3.client('sts').get_caller_identity()['Account']
        policy_arn = f'arn:aws:iam::{account_id}:policy/BedrockClaudeAccess'
    
    # Attach the policy to the role
    try:
        iam_client.attach_role_policy(
            RoleName=role_name,
            PolicyArn=policy_arn
        )
        print(f"Successfully attached Bedrock policy to role {role_name}")
    except Exception as e:
        print(f"Error attaching policy: {str(e)}")

# Execute the functions
update_lambda_timeout()
role_name = get_lambda_role()
create_bedrock_policy(role_name)

Now start the ingestion job. 

In [None]:
# ensure that the kb is available
time.sleep(30)
# sync knowledge base
knowledge_base_custom.start_ingestion_job()

In [None]:
kb_id_custom = knowledge_base_custom.get_knowledge_base_id()

### 2.2 Test the Knowledge Base
Now that the Knowledge Base is available, we can test it using the [**retrieve**](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve.html) and [**retrieve_and_generate**](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent-runtime/client/retrieve_and_generate.html) functions. 

#### Testing Knowledge Base with Retrieve and Generate API

Let's first test the knowledge base using the retrieve and generate API. With this API, Bedrock takes care of retrieving the necessary references from the knowledge base and generating the final answer using a foundation model from Bedrock.

query = `Provide a summary of consolidated statements of cash flows of Octank Financial for the fiscal years ended December 31, 2019.`

The right response for this query as per ground truth QA pair is: 

```
The cash flow statement for Octank Financial in the year ended December 31, 2019 reveals the following:
- Cash generated from operating activities amounted to $710 million, which can be attributed to a $700 million profit and non-cash charges such as depreciation and amortization.
- Cash outflow from investing activities totaled $240 million, with major expenditures being the acquisition of property, plant, and equipment ($200 million) and marketable securities ($60 million), partially offset by the sale of property, plant, and equipment ($40 million) and maturing marketable securities ($20 million).
- Financing activities resulted in a cash inflow of $350 million, stemming from the issuance of common stock ($200 million) and long-term debt ($300 million), while common stock repurchases ($50 million) and long-term debt payments ($100 million) reduced the cash inflow. 
Overall, Octank Financial experienced a net cash enhancement of $120 million in 2019, bringing their total cash and cash equivalents to $210 million.
```

In [None]:
query = "Provide a summary of consolidated statements of cash flows of Octank Financial for the fiscal years ended December 31, 2019."

In [None]:
time.sleep(10)

response = bedrock_agent_runtime_client.retrieve_and_generate(
    input={
        "text": query
    },
    retrieveAndGenerateConfiguration={
        "type": "KNOWLEDGE_BASE",
        "knowledgeBaseConfiguration": {
            'knowledgeBaseId': kb_id_custom,
            "modelArn": "arn:aws:bedrock:{}::foundation-model/{}".format(region, foundation_model),
            "retrievalConfiguration": {
                "vectorSearchConfiguration": {
                    "numberOfResults":5
                } 
            }
        }
    }
)

print(response['output']['text'],end='\n'*2)

As you can see, the `RetrieveAndGenerate` API returns the final response directly. Let's now look at the citations it returns, since the primary focus of this notebook is to observe the retrieved chunks and citations the model used while generating the response. When we provide relevant context to the foundation model along with the query, it is more likely to generate a high-quality response. 

In [None]:
response_custom = response['citations'][0]['retrievedReferences']
print("# of citations or chunks used to generate the response: ", len(response_custom))
citations_rag_print(response_custom)

Let's now retrieve the source information from the knowledge base with the retrieve API.

#### Testing Knowledge Base with Retrieve API
If you need an extra layer of control, you can retrieve the chunks that best match your query using the retrieve API. In this setup, we can configure the desired number of results and control the final answer with your own application logic. The API then provides you with the matching content, its S3 location, the similarity score and the chunk metadata.

In [None]:
response_custom_ret = bedrock_agent_runtime_client.retrieve(
    knowledgeBaseId=kb_id_custom, 
    retrievalConfiguration={
        "vectorSearchConfiguration": {
            "numberOfResults":5,
        } 
    },
    retrievalQuery={
        'text': query
    }
)
print("# of retrieved results: ", len(response_custom_ret['retrievalResults']))
response_print(response_custom_ret)

### Each retrieved chunk now contains the generated context followed by the original chunk text
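
The Lambda function in this notebook writes each chunk as the generated context, a blank line, and then the original chunk text. The sketch below separates the two parts again for inspection. The function name `split_context` is hypothetical, and the split assumes the stored text is returned unchanged and that the generated context itself contains no blank line.

```python
def split_context(stored_text):
    """Split a stored chunk into (generated context, original chunk text)."""
    context, separator, body = stored_text.partition("\n\n")
    if not separator:
        # No blank line found: treat the whole text as the chunk body.
        return "", stored_text
    return context, body


context, body = split_context(
    "This chunk covers 2019 financing activities.\n\nFinancing activities resulted in ..."
)
print(context)
```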

For the single query we evaluated, both knowledge bases returned the correct response. However, when you are building a RAG application, you need to evaluate with a large number of questions and answers to measure accuracy improvements. In the next step, we will use the open source RAG Assessment (RAGAS) framework to evaluate the responses on `your dataset` using metrics that assess the quality of the context or search results.
We will focus on 2 metrics: 

1. Context recall
2. Context precision
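
To build intuition for context recall, the toy sketch below computes the fraction of ground-truth statements that appear in the retrieved contexts. This is a deliberate simplification: it uses case-insensitive substring matching, whereas RAGAS relies on an LLM to judge whether each statement is supported. The function name and the sample data are hypothetical.

```python
def toy_context_recall(ground_truth_statements, retrieved_contexts):
    """Fraction of ground-truth statements found verbatim in the contexts."""
    if not ground_truth_statements:
        return 0.0
    joined = " ".join(retrieved_contexts).lower()
    supported = sum(
        1 for statement in ground_truth_statements if statement.lower() in joined
    )
    return supported / len(ground_truth_statements)


statements = ["net income increased", "long-term debt was repaid"]
contexts = ["In 2021, net income increased by 10%."]
recall = toy_context_recall(statements, contexts)
print(recall)
```

Only one of the two statements is covered by the retrieved context here, so a retriever that also surfaced the debt repayment passage would score higher.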

## 3. Evaluating search results using RAG Assessment (RAGAS) framework on your dataset
You can use the RAGAS framework to evaluate your results for each chunking strategy. This approach gives you evidence for deciding which chunking strategy to use for your dataset. 

The approach below provides heuristics on which strategy to use, based on the default parameters recommended by Amazon Bedrock Knowledge Bases. 

In [None]:
print("Standard: ", kb_id_standard)
print("Custom chunking: ", kb_id_custom)

In [None]:
from botocore.client import Config
from langchain.llms.bedrock import Bedrock
from langchain_aws import ChatBedrock
from langchain_aws import BedrockEmbeddings

bedrock_client = boto3.client('bedrock-runtime')

TEXT_GENERATION_MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0"
EVALUATION_MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"

llm_for_evaluation = ChatBedrock(model_id=EVALUATION_MODEL_ID, client=bedrock_client)
bedrock_embeddings = BedrockEmbeddings(model_id="amazon.titan-embed-text-v2:0", client=bedrock_client)

In [None]:
def retrieve_and_generate(query, kb_id):
    start = time.time()
    response = bedrock_agent_runtime_client.retrieve_and_generate(
        input={
            'text': query
        },
        retrieveAndGenerateConfiguration={
            'type': 'KNOWLEDGE_BASE',
            'knowledgeBaseConfiguration': {
                'knowledgeBaseId': kb_id,
                'modelArn': TEXT_GENERATION_MODEL_ID
            }
        }
    )
    time_spent = time.time() - start
    print(f"[Response]\n{response['output']['text']}\n")
    print(f"[Invocation time]\n{time_spent}\n")

    return response

## Use RAGAS to evaluate RAG

As RAGAS aims to be a reference-free evaluation framework, the required preparations of the evaluation dataset are minimal. You will need to prepare question and ground_truths pairs from which you can prepare the remaining information through inference as shown below. If you are not interested in the context_recall metric, you don’t need to provide the ground_truths information. In this case, all you need to prepare are the questions.
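
The sketch below shows how one evaluation record could be assembled from a question, its ground truth and a `RetrieveAndGenerate`-style response. The key names (`user_input`, `retrieved_contexts`, `response`, `reference`) mirror the fields of the `SingleTurnSample` class in recent RAGAS releases; that mapping, the function name `build_eval_record` and the sample response are assumptions, so check them against your installed version.

```python
def build_eval_record(question, ground_truth, rag_response):
    """Assemble one evaluation record as a plain dictionary."""
    contexts = [
        ref["content"]["text"]
        for citation in rag_response.get("citations", [])
        for ref in citation.get("retrievedReferences", [])
    ]
    return {
        "user_input": question,            # the question asked
        "retrieved_contexts": contexts,    # chunks the answer was grounded on
        "response": rag_response["output"]["text"],
        "reference": ground_truth,         # the expected answer
    }


# Hypothetical response with the same nesting as RetrieveAndGenerate output.
sample_rag_response = {
    "output": {"text": "Net income increased."},
    "citations": [
        {"retrievedReferences": [{"content": {"text": "Net income rose in 2021."}}]}
    ],
}
record = build_eval_record(
    "Why did operating cash flow increase?",
    "Mainly due to higher net income.",
    sample_rag_response,
)
print(record["retrieved_contexts"])
```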

In [None]:
from ragas import SingleTurnSample, EvaluationDataset
from ragas import evaluate
from ragas.metrics import (
    context_recall,
    context_precision,
    answer_correctness
)

#specify the metrics here
metrics = [
    context_recall,
    context_precision,
    answer_correctness
]


questions = [
    "What was the primary reason for the increase in net cash provided by operating activities for Octank Financial in 2021?",
    "In which year did Octank Financial have the highest net cash used in investing activities, and what was the primary reason for this?",
    "What was the primary source of cash inflows from financing activities for Octank Financial in 2021?",
    "Based on the information provided, what can you infer about Octank Financial's overall financial health and growth prospects?"
]
ground_truths = [
    "The increase in net cash provided by operating activities was primarily due to an increase in net income and favorable changes in operating assets and liabilities.",
    "Octank Financial had the highest net cash used in investing activities in 2021, at $360 million. The primary reason for this was an increase in purchases of property, plant, and equipment and marketable securities",
    "The primary source of cash inflows from financing activities for Octank Financial in 2021 was an increase in proceeds from the issuance of common stock and long-term debt.",
    "Based on the information provided, Octank Financial appears to be in a healthy financial position and has good growth prospects. The company has consistently increased its net cash provided by operating activities, indicating strong profitability and efficient management of working capital. Additionally, Octank Financial has been investing in long-term assets, such as property, plant, and equipment, and marketable securities, which suggests plans for future growth and expansion. The company has also been able to finance its growth through the issuance of common stock and long-term debt, indicating confidence from investors and lenders. Overall, Octank Financial's steady increase in cash and cash equivalents over the past three years provides a strong foundation for future growth and investment opportunities."
]

In [None]:
def prepare_eval_dataset(kb_id, questions, ground_truths):
    # Lists to store SingleTurnSample objects
    samples = []
    
    for question, ground_truth in zip(questions, ground_truths):
        # Get response and context from your retrieval system
        response = retrieve_and_generate(question, kb_id)
        answer = response["output"]["text"]
        
        # Process contexts
        contexts = []
        for citation in response["citations"]:
            context_texts = [
                ref["content"]["text"]
                for ref in citation["retrievedReferences"]
                if "content" in ref and "text" in ref["content"]
            ]
            contexts.extend(context_texts)
        
        # Create a SingleTurnSample
        sample = SingleTurnSample(
            user_input=question,
            retrieved_contexts=contexts,
            response=answer,
            reference=ground_truth
        )
        
        # Add the sample to our list
        samples.append(sample)
        
        # Optional rate limiting: uncomment if the calls are throttled
        # time.sleep(10)

    # Create EvaluationDataset from samples
    eval_dataset = EvaluationDataset(samples=samples)
    
    return eval_dataset
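
To make the shape of the data explicit, here is a minimal sketch of the context-extraction step, run against a hand-written response. The `mock_response` dictionary is an assumption that mirrors only the fields read by `prepare_eval_dataset` (the answer text and the text of each retrieved reference); a real response contains more fields. `extract_contexts` is a hypothetical helper name, not part of the notebook.

```python
def extract_contexts(response):
    """Flatten the text of every retrieved reference across all citations."""
    contexts = []
    for citation in response.get("citations", []):
        for ref in citation.get("retrievedReferences", []):
            text = ref.get("content", {}).get("text")
            if text is not None:
                contexts.append(text)
    return contexts


# Hand-written stand-in for a retrieve_and_generate response (assumption).
mock_response = {
    "output": {"text": "Net income increased."},
    "citations": [
        {
            "retrievedReferences": [
                {"content": {"text": "chunk A"}},
                {"content": {}},  # a reference without text is skipped
            ]
        },
        {"retrievedReferences": [{"content": {"text": "chunk B"}}]},
    ],
}

print(extract_contexts(mock_response))
```

Each string in the returned list becomes one entry of `retrieved_contexts` in the `SingleTurnSample`, so an empty `citations` list simply yields a sample with no contexts.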

In [None]:
contextual_chunking_dataset = prepare_eval_dataset(kb_id_custom, questions, ground_truths)

In [None]:
contextual_chunking_result = evaluate(
    dataset=contextual_chunking_dataset,
    metrics=metrics,
    llm=llm_for_evaluation,
    embeddings=bedrock_embeddings,
)
contextual_chunking_result_df = contextual_chunking_result.to_pandas()

In [None]:
contextual_chunking_result_df

In [None]:
default_chunking_dataset = prepare_eval_dataset(kb_id_standard, questions, ground_truths)

In [None]:
default_chunking_result = evaluate(
    dataset=default_chunking_dataset,
    metrics=metrics,
    llm=llm_for_evaluation,
    embeddings=bedrock_embeddings,
)

default_chunking_result_df = default_chunking_result.to_pandas()

In [None]:
default_chunking_result_df

In [None]:
default_chunking_avg_metrics = default_chunking_result_df[['context_recall', 'context_precision', 'answer_correctness']].mean()
contextual_chunking_avg_metrics = contextual_chunking_result_df[['context_recall', 'context_precision', 'answer_correctness']].mean()

In [None]:
import pandas as pd
comparison_df = pd.DataFrame({
    'Default Chunking': default_chunking_avg_metrics,
    'Contextual Chunking': contextual_chunking_avg_metrics
})

In [None]:
def highlight_max(s):
    is_max = s == s.max()
    return ['background-color: #90EE90' if v else '' for v in is_max]

comparison_df.style.apply(highlight_max, axis=1, subset=['Default Chunking', 'Contextual Chunking'])

## Conclusion

Using contextual retrieval as the RAG chunking strategy can improve retrieval accuracy. Traditional RAG systems lose context when they split documents into chunks; contextual retrieval addresses this by adding relevant contextual information to each chunk before embedding it.
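
The idea of adding context before embedding can be illustrated with a small sketch. It is not the Lambda function used in this notebook: `contextualize_chunk` and `stub_context_generator` are hypothetical names, and the stub only stands in for the step where a foundation model would describe how a chunk fits into the full document.

```python
def contextualize_chunk(chunk_text, context):
    """Prepend a short situating context to a chunk before it is embedded."""
    context = context.strip()
    if not context:
        # Nothing to add: embed the chunk unchanged.
        return chunk_text
    return f"{context}\n\n{chunk_text}"


def stub_context_generator(document_title, chunk_text):
    # Placeholder (assumption): a real pipeline would ask a foundation model
    # to describe where this chunk sits within the whole document.
    return f"This excerpt is from {document_title}."


chunk = "Net cash used in investing activities was $360 million."
context = stub_context_generator("the Octank Financial 10K report", chunk)
print(contextualize_chunk(chunk, context))
```

The combined string, rather than the bare chunk, is what would be embedded and indexed, so a query that mentions the company or the report can still match a chunk that never names them.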

<div class="alert alert-block alert-warning">
<b>Note:</b> Remember to delete the knowledge bases, the OpenSearch Serverless (OSS) index, and the related IAM roles and policies to avoid incurring charges.
</div>
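
As a starting point for the cleanup, the following sketch removes a knowledge base and its data sources through the `bedrock-agent` API. It is an assumption about how the cleanup could be scripted, not the notebook's own teardown: `delete_kb_and_data_sources` is a hypothetical helper, it reads only the first page of data sources, it does not wait for the asynchronous deletions to finish, and it leaves the OpenSearch Serverless resources and the IAM roles and policies in place.

```python
def delete_kb_and_data_sources(agent_client, kb_id):
    """Request deletion of every data source of a knowledge base, then of the knowledge base.

    `agent_client` is expected to behave like boto3.client('bedrock-agent').
    Returns the IDs of the data sources for which deletion was requested.
    """
    deleted = []
    # Assumption: one page of results covers the data sources created in this notebook.
    response = agent_client.list_data_sources(knowledgeBaseId=kb_id)
    for summary in response["dataSourceSummaries"]:
        agent_client.delete_data_source(
            knowledgeBaseId=kb_id,
            dataSourceId=summary["dataSourceId"],
        )
        deleted.append(summary["dataSourceId"])
    agent_client.delete_knowledge_base(knowledgeBaseId=kb_id)
    return deleted
```

In this notebook it could be called as `delete_kb_and_data_sources(bedrock_agent_client, kb_id_standard)` and again with `kb_id_custom`.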