# Enhancing RAG with Contextual Retrieval

Retrieval Augmented Generation (RAG) enables Claude to leverage your internal knowledge bases, codebases, or any other corpus of documents when providing a response. Enterprises are increasingly building RAG applications to improve workflows in customer support, Q&A over internal company documents, financial & legal analysis, code generation, and much more.

In a [separate guide](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/retrieval_augmented_generation/guide.ipynb), we walked through setting up a basic retrieval system, demonstrated how to evaluate its performance, and then outlined a few techniques to improve performance. In this guide, we present a technique for improving retrieval performance: Contextual Embeddings.

In traditional RAG, documents are typically split into smaller chunks for efficient retrieval. While this approach works well for many applications, it can lead to problems when individual chunks lack sufficient context. Contextual Embeddings solve this problem by adding relevant context to each chunk before embedding. This method improves the quality of each embedded chunk, allowing for more accurate retrieval and thus better overall performance. Averaged across all data sources we tested, Contextual Embeddings reduced the top-20-chunk retrieval failure rate by 35%.

The same chunk-specific context can also be used with BM25 search to further improve retrieval performance. We introduce this technique in the “Contextual BM25” section.

In this guide, we'll demonstrate how to build and optimize a Contextual Retrieval system using a dataset of 9 codebases as our knowledge base. We'll walk through:

1) Setting up a basic retrieval pipeline to establish a baseline for performance.

2) Contextual Embeddings: what it is, why it works.

3) Implementing Contextual Embeddings and demonstrating performance improvements.

4) Contextual BM25: improving performance with *contextual* BM25 hybrid search.

5) Improving performance with reranking.

### Note:

Obtain the [dataset](https://github.com/anthropics/anthropic-cookbook/tree/main/skills/contextual-embeddings/data) from the original implementation in Anthropic's cookbook.

Prompt caching for Bedrock is currently not available. You can refer to [semantic caching](https://aws.amazon.com/blogs/database/improve-speed-and-reduce-cost-for-generative-ai-workloads-with-a-persistent-semantic-cache-in-amazon-memorydb/) for a custom implementation.

### Evaluation Metrics & Dataset:

We use a pre-chunked dataset of 9 codebases, all of which have been chunked with a basic character-splitting mechanism. Our evaluation dataset contains 248 queries, each paired with one or more 'golden chunks'. We'll use a metric called Pass@k to evaluate performance. Pass@k checks whether the golden chunk was present in the first k chunks retrieved for each query. In the run recorded in this notebook, Contextual Embeddings improved Pass@10 from ~84% to ~90%.
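To make the metric concrete, here is a minimal sketch of Pass@k on toy data. The function names and the chunk IDs (`c1`, `c2`, ...) are illustrative assumptions, not part of the dataset. Like the evaluation code later in this guide, it gives partial credit when a query has several golden chunks.

```python
from typing import Dict, List, Sequence


def pass_at_k(retrieved: Sequence[str], golden: Sequence[str], k: int) -> float:
    """Fraction of golden chunks that appear in the first k retrieved chunks."""
    if not golden:
        return 0.0
    top_k = set(retrieved[:k])
    return sum(1 for g in golden if g in top_k) / len(golden)


def mean_pass_at_k(runs: List[Dict[str, List[str]]], k: int) -> float:
    """Average the per-query score and express it as a percentage."""
    scores = [pass_at_k(r["retrieved"], r["golden"], k) for r in runs]
    return 100.0 * sum(scores) / len(scores)


# Hypothetical retrieval results for two queries
runs = [
    {"retrieved": ["c3", "c1", "c7"], "golden": ["c1"]},
    {"retrieved": ["c2", "c9", "c4"], "golden": ["c4", "c8"]},
]
print(mean_pass_at_k(runs, k=2))  # 50.0
print(mean_pass_at_k(runs, k=3))  # 75.0
```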

You can find the code files and their chunks in `data/codebase_chunks.json` and the evaluation dataset in `data/evaluation_set.jsonl`.

## Table of Contents

1) Setup

2) Basic RAG

3) Contextual Embeddings

4) Contextual BM25

5) Reranking

## Setup

We'll need a few libraries, including:

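1) `boto3` for invoking Anthropic Claude 3 on Amazon Bedrock

2) `cohere-aws` for invoking the Cohere Rerank model hosted on a SageMaker endpoint

3) `pandas` and `numpy` for data manipulation and visualization

4) `opensearch-py`, `requests-aws4auth`, and `tqdm`, which the cells below import for OpenSearch access, request signing, and progress bars

In [2]:
%pip install boto3 cohere-aws pandas numpy opensearch-py requests-aws4auth tqdm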

### OpenSearch Serverless, SageMaker JumpStart, Bedrock access

Set up a collection for OpenSearch Serverless using the instructions [here](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-getting-started.html).  

Launch any of the embedding models from [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html#jumpstart-open-use-studio).  
**bge-m3** is used here for its long context support of up to 8192 input tokens.  
You can also use embedding models from Bedrock, but none of them supports more than 2048 input tokens as of this writing.  
You can work around this limitation by further breaking up the dataset chunks into pieces of 2048 tokens or fewer.  
**cohere rerank 3** is used as the reranker.
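If you do need to fit a smaller input limit, one possible approach is to repack each chunk into pieces that stay under a token budget. The sketch below is an assumption-heavy illustration: `split_to_budget` is a hypothetical helper, and the default `count_tokens` just counts whitespace-separated words, which is only a rough proxy. In practice you would pass in the tokenizer of the embedding model you actually use.

```python
from typing import Callable, List


def split_to_budget(
    text: str,
    max_tokens: int = 2048,
    # Assumption: whitespace word count stands in for a real tokenizer
    count_tokens: Callable[[str], int] = lambda s: len(s.split()),
) -> List[str]:
    """Greedily pack whole lines into pieces whose estimated token count stays within max_tokens."""
    pieces: List[str] = []
    current: List[str] = []
    for line in text.splitlines(keepends=True):
        candidate = "".join(current) + line
        if current and count_tokens(candidate) > max_tokens:
            # Adding this line would exceed the budget, so close the current piece
            pieces.append("".join(current))
            current = [line]
        else:
            current.append(line)
    if current:
        pieces.append("".join(current))
    # Note: a single line longer than max_tokens is kept whole and will exceed the budget
    return pieces


sample = "a b c\nd e f\ng h i\n"
print(split_to_budget(sample, max_tokens=6))
```

Splitting on line boundaries keeps individual lines of code intact, which tends to matter for a codebase dataset like this one.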

Get access to Bedrock FMs using instructions [here](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html). **claude-3-haiku** is used as the LLM in this example.

[Configure AWS credentials](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-files.html) with access to invoke the OpenSearch, SageMaker JumpStart, and Bedrock endpoints if you are running this locally.  
For SageMaker Studio notebooks, you can follow the instructions [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) to grant similar access to the SageMaker execution role.

In [1]:
llm_model_id = '<llm model id>' # bedrock model id
embedder_endpoint_name = '<sagemaker endpoint name for embedding model>' # sagemaker endpoint

### Initialize a Vector DB Class

Note that we are using [approximate k-NN](https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/) for search here, which is faster but less accurate than the search in the original implementation.  
More importantly, the relative improvements across the different methods remain consistent.
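For comparison, an exact (brute-force) search scores the query against every stored vector and sorts the results. The sketch below shows that idea with cosine similarity in plain Python. `exact_top_k` and the toy vectors are illustrative assumptions and are not used elsewhere in this guide.

```python
import math
from typing import List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def exact_top_k(
    query_vec: Sequence[float], doc_vecs: List[Sequence[float]], k: int = 5
) -> List[Tuple[int, float]]:
    """Score every document against the query and return the k best (index, similarity) pairs."""
    scored = [(i, cosine(query_vec, vec)) for i, vec in enumerate(doc_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 2-dimensional "embeddings"
docs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print([i for i, _ in exact_top_k([1.0, 0.1], docs, k=2)])  # [0, 2]
```

Exact search costs one similarity computation per stored vector for every query, which is the cost an approximate index such as HNSW trades a little recall to avoid.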


In [2]:
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth

credentials = boto3.Session().get_credentials()
region = '<your opensearch region>'
service = 'aoss'
host_name = '<your opensearch endpoint specific to your collection>'
awsauth = AWS4Auth(
    credentials.access_key, 
    credentials.secret_key,
    region, 
    service, 
    session_token=credentials.token
)

# Build the OpenSearch client
open_search_client = OpenSearch(
    hosts=[{'host': host_name, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    timeout=300
)

index_body = {
    "settings": {
        "index.knn": True,
        "number_of_shards": 1,
        "knn.algo_param.ef_search": 512,
        "number_of_replicas": 0,
    },
    "mappings": {
        "properties": {
            "embedding": {
                "type": "knn_vector",
                "dimension": 1024,
                "method": {
                    "name": "hnsw",
                    "engine": "faiss"
                }
            },
            "doc_id": {"type":"text"},
            "original_uuid": {"type":"text"},
            "chunk_id": {"type":"text"},
            "original_index": {"type":"integer"},
            "content": {"type":"text"},
        }
    }
}

index_name = "basic-idx"

# Create index
response = open_search_client.indices.create(index_name, body=index_body)
print('\nCreating index:')
print(response)

In [3]:
import os
import pickle
import json
import numpy as np
from typing import List, Dict, Any
from tqdm import tqdm
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth

class VectorDB:
    def __init__(self, name: str):
        self.sm_runtime_client = boto3.client('runtime.sagemaker')

        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/vector_db.pkl"

    def _query_endpoint(self, encoded_json, content_type):
        response = self.sm_runtime_client.invoke_endpoint(EndpointName=embedder_endpoint_name, ContentType=content_type, Body=encoded_json)
        return json.loads(response['Body'].read())

    def load_data(self, dataset: List[Dict[str, Any]]):
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)
        
        with tqdm(total=total_chunks, desc="Processing chunks") as pbar:
            for doc in dataset:
                for chunk in doc['chunks']:
                    texts_to_embed.append(chunk['content'])
                    metadata.append({
                        'doc_id': doc['doc_id'],
                        'original_uuid': doc['original_uuid'],
                        'chunk_id': chunk['chunk_id'],
                        'original_index': chunk['original_index'],
                        'content': chunk['content']
                    })
                    pbar.update(1)

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()
        
        print(f"Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}")

    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 128
        with tqdm(total=len(texts), desc="Embedding chunks") as pbar:
            result = []
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                batch_result = self._query_endpoint(json.dumps(batch).encode('utf-8'), 'application/x-text')['embedding']
                result.extend(batch_result)
                pbar.update(len(batch))
        
        self.embeddings = result
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self._query_endpoint(json.dumps([query]).encode('utf-8'), 'application/x-text')['embedding']
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        query_json = {
            "size": k,
            "query": {
                "knn": {
                    "embedding": {
                        "vector": query_embedding[0], 
                        "k": k}
                    },
            }
        }

        response = open_search_client.search(
            body = query_json,
            index = index_name
        )

        top_hits = response['hits']['hits']

        results = list(map(lambda x: {'metadata':{'doc_id':x['_source']['doc_id'],'original_uuid':x['_source']['original_uuid'],'chunk_id':x['_source']['chunk_id'],'original_index':x['_source']['original_index'],'content':x['_source']['content']},'similarity':x['_score']},top_hits))

        return results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])

    def validate_embedded_chunks(self):
        unique_contents = set()
        for meta in self.metadata:
            unique_contents.add(meta['content'])
    
        print("Validation results:")
        print(f"Total embedded chunks: {len(self.metadata)}")
        print(f"Unique embedded contents: {len(unique_contents)}")
    
        if len(self.metadata) != len(unique_contents):
            print("Warning: There may be duplicate chunks in the embedded data.")
        else:
            print("All embedded chunks are unique.")

In [4]:
# Load your transformed dataset
with open('data/codebase_chunks.json', 'r') as f:
    transformed_dataset = json.load(f)

# Initialize the VectorDB
base_db = VectorDB("base_db")

# Load and process the data
base_db.load_data(transformed_dataset)

Loading vector database from disk.


In [None]:
# bulk index the documents into AOSS
docs = []
for i in range(len(base_db.embeddings)):
    docs.append({'index':{ "_index": index_name}})
    docs.append({**base_db.metadata[i],'embedding':base_db.embeddings[i]})
    
open_search_client.bulk(body=docs)
    

## Basic RAG

To get started, we'll set up a basic RAG pipeline using a bare-bones approach, sometimes called 'Naive RAG'. A basic RAG pipeline includes the following 3 steps:

1) Chunk documents (our dataset is already pre-chunked by a basic character splitter)

2) Embed each chunk

3) Use cosine similarity to retrieve the most relevant chunks for a query

In [5]:
import json
from typing import List, Dict, Any, Callable, Union
from tqdm import tqdm

def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Load JSONL file and return a list of dictionaries."""
    with open(file_path, 'r') as file:
        return [json.loads(line) for line in file]

def evaluate_retrieval(queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20) -> Dict[str, float]:
    total_score = 0
    total_queries = len(queries)
    
    for query_item in tqdm(queries, desc="Evaluating retrieval"):
        query = query_item['query']
        golden_chunk_uuids = query_item['golden_chunk_uuids']
        
        # Find all golden chunk contents
        golden_contents = []
        for doc_uuid, chunk_index in golden_chunk_uuids:
            golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)
            if not golden_doc:
                print(f"Warning: Golden document not found for UUID {doc_uuid}")
                continue
            
            golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)
            if not golden_chunk:
                print(f"Warning: Golden chunk not found for index {chunk_index} in document {doc_uuid}")
                continue
            
            golden_contents.append(golden_chunk['content'].strip())
        
        if not golden_contents:
            print(f"Warning: No golden contents found for query: {query}")
            continue
        
        retrieved_docs = retrieval_function(query, db, k=k)
        
        # Count how many golden chunks are in the top k retrieved documents
        chunks_found = 0
        for golden_content in golden_contents:
            for doc in retrieved_docs[:k]:
                retrieved_content = doc['metadata'].get('original_content', doc['metadata'].get('content', '')).strip()
                if retrieved_content == golden_content:
                    chunks_found += 1
                    break
        
        query_score = chunks_found / len(golden_contents)
        total_score += query_score
    
    average_score = total_score / total_queries
    pass_at_n = average_score * 100
    return {
        "pass_at_n": pass_at_n,
        "average_score": average_score,
        "total_queries": total_queries
    }

def retrieve_base(query: str, db, k: int = 20) -> List[Dict[str, Any]]:
    """
    Retrieve relevant documents using either VectorDB or ContextualVectorDB.
    
    :param query: The query string
    :param db: The VectorDB or ContextualVectorDB instance
    :param k: Number of top results to retrieve
    :return: List of retrieved documents
    """
    return db.search(query, k=k)

def evaluate_db(db, original_jsonl_path: str, k):
    # Load the original JSONL data for queries and ground truth
    original_data = load_jsonl(original_jsonl_path)
    
    # Evaluate retrieval
    results = evaluate_retrieval(original_data, retrieve_base, db, k)
    print(f"Pass@{k}: {results['pass_at_n']:.2f}%")
    print(f"Total Score: {results['average_score']}")
    print(f"Total queries: {results['total_queries']}")
    # Return the metrics so that callers assigning the result do not get None
    return results

In [12]:
results5 = evaluate_db(base_db, 'data/evaluation_set.jsonl', 5)
results10 = evaluate_db(base_db, 'data/evaluation_set.jsonl', 10)
results20 = evaluate_db(base_db, 'data/evaluation_set.jsonl', 20)

Evaluating retrieval: 100%|██████████| 248/248 [00:08<00:00, 30.62it/s]


Pass@5: 76.15%
Total Score: 0.7614727342549924
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [00:08<00:00, 28.44it/s]


Pass@10: 83.54%
Total Score: 0.8354454685099846
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [00:07<00:00, 31.04it/s]

Pass@20: 88.71%
Total Score: 0.8870967741935484
Total queries: 248





## Contextual Embeddings

With basic RAG, each embedded chunk contains a potentially useful piece of information, but these chunks lack context. With Contextual Embeddings, we create a variation on the embedding itself by adding more context to each text chunk before embedding it. Specifically, we use Claude to create a concise context that explains the chunk using the context of the overall document. In the case of our codebases dataset, we can provide both the chunk and the full file that each chunk was found within to an LLM, then produce the context. Then, we will combine this 'context' and the raw text chunk together into a single text block prior to creating each embedding.
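As a small illustration of that last step, the sketch below joins a chunk and its generated context into the single block that gets embedded, using the same order as `process_chunk` later in this guide (chunk first, then context). The helper name and the example context string are made up for illustration. In the real pipeline the context comes from Claude.

```python
def build_text_to_embed(chunk: str, context: str) -> str:
    """Join the raw chunk and its generated context into one block for embedding."""
    return f"{chunk}\n\n{context}"


# Hypothetical chunk and context, for illustration only
chunk = "def retry(fn, attempts=3):\n    ..."
context = "Helper from a hypothetical http_client.py module that wraps requests in retry logic."
print(build_text_to_embed(chunk, context))
```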

In [6]:
index_name = "contextual-idx"

index_body = {
    "settings": {
        "index.knn": True,
        "number_of_shards": 1,
        "knn.algo_param.ef_search": 512,
        "number_of_replicas": 0,
    },
    "mappings": {
        "properties": {
            "embedding": {
                "type": "knn_vector",
                "dimension": 1024,
                "method": {
                    "name": "hnsw",
                    "engine": "faiss"
                }
            },
            "doc_id": {"type":"text"},
            "original_uuid": {"type":"text"},
            "chunk_id": {"type":"text"},
            "original_index": {"type":"integer"},
            "original_content": {"type":"text"},
            "contextualized_content": {"type":"text"}
        }
    }
}

response = open_search_client.indices.create(index_name, body=index_body)

print('\nCreating index:')
print(response)

In [7]:
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""

CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>

Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""

import boto3
client = boto3.client('bedrock-runtime')
def situate_context(doc: str, chunk: str) -> str:
    messages=[
        {
            "role": "user", 
            "content": [
                {
                    "type": "text",
                    "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                },
                {
                    "type": "text",
                    "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                }
            ]
        }
    ]
    body=json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "messages": messages,
            "temperature": 0.0,
        }  
    ) 
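    response = client.invoke_model(modelId=llm_model_id,body=body)
    # Parse the response body and return only the generated context text
    response_body = json.loads(response['body'].read())
    return response_body['content'][0]['text']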

In [10]:
import os
import pickle
import json
import numpy as np
# import voyageai
from typing import List, Dict, Any
from tqdm import tqdm
# import anthropic
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

class ContextualVectorDB:
    def __init__(self, name: str):
        self.embedding_client = boto3.client('runtime.sagemaker')
        self.llm_client = boto3.client('bedrock-runtime')
        self.name = name
        self.db_path = f"./data/{name}/contextual_vector_db.pkl"
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}

        self.token_counts = {
            'input': 0,
            'output': 0,
            'cache_read': 0,
            'cache_creation': 0
        }
        self.token_lock = threading.Lock()

    def situate_context(self, doc: str, chunk: str) -> tuple[str, Any]:
        DOCUMENT_CONTEXT_PROMPT = """
        <document>
        {doc_content}
        </document>
        """

        CHUNK_CONTEXT_PROMPT = """
        Here is the chunk we want to situate within the whole document
        <chunk>
        {chunk_content}
        </chunk>

        Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
        Answer only with the succinct context and nothing else.
        """

        messages=[
            {
                "role": "user", 
                "content": [
                    {
                        "type": "text",
                        "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                    },
                    {
                        "type": "text",
                        "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                    }
                ]
            }
        ]
        body=json.dumps(
            {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 1024,
                "messages": messages,
                "temperature": 0.0,
            }  
        ) 
        response = self.llm_client.invoke_model(modelId=llm_model_id,body=body)
        body = json.loads(response.get('body').read())
        return body['content'][0]['text'],body['usage']

    def load_data(self, dataset: List[Dict[str, Any]], parallel_threads: int = 1):
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)

        def process_chunk(doc, chunk):
            contextualized_text, usage = self.situate_context(doc['content'], chunk['content'])
            with self.token_lock:
                self.token_counts['input'] += usage['input_tokens']
                self.token_counts['output'] += usage['output_tokens']
            
            return {
                'text_to_embed': f"{chunk['content']}\n\n{contextualized_text}",
                'metadata': {
                    'doc_id': doc['doc_id'],
                    'original_uuid': doc['original_uuid'],
                    'chunk_id': chunk['chunk_id'],
                    'original_index': chunk['original_index'],
                    'original_content': chunk['content'],
                    'contextualized_content': contextualized_text
                }
            }

        print(f"Processing {total_chunks} chunks with {parallel_threads} threads")
        with ThreadPoolExecutor(max_workers=parallel_threads) as executor:
            futures = []
            for doc in dataset:
                for chunk in doc['chunks']:
                    futures.append(executor.submit(process_chunk, doc, chunk))
            
            for future in tqdm(as_completed(futures), total=total_chunks, desc="Processing chunks"):
                result = future.result()
                texts_to_embed.append(result['text_to_embed'])
                metadata.append(result['metadata'])

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()

        print(f"Contextual Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}")
        print(f"Total input tokens without caching: {self.token_counts['input']}")
        print(f"Total output tokens: {self.token_counts['output']}")
    
    def _query_endpoint(self,encoded_json, content_type):
        response = self.embedding_client.invoke_endpoint(EndpointName=embedder_endpoint_name, ContentType=content_type, Body=encoded_json)
        return json.loads(response['Body'].read())

    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 128
        result = [
            self._query_endpoint(json.dumps(texts[i : i + batch_size]).encode('utf-8'), 'application/x-text')['embedding']
            for i in range(0, len(texts), batch_size)
        ]

        arr = []
        for res in result:
            arr = arr + res
        self.embeddings = arr
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self._query_endpoint(json.dumps([query]).encode('utf-8'), 'application/x-text')['embedding']
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        query_json = {
            "size": k,
            "query": {
                "knn": {
                    "embedding": {
                        "vector": query_embedding[0], 
                        "k": k}
                    },
            }
        }

        response = open_search_client.search(
            body = query_json,
            index = index_name
        )

        top_hits = response['hits']['hits']

        results = list(map(lambda x: {'metadata':{'doc_id':x['_source']['doc_id'],'original_uuid':x['_source']['original_uuid'],'chunk_id':x['_source']['chunk_id'],'original_index':x['_source']['original_index'],'original_content':x['_source']['original_content'],'contextualized_content':x['_source']['contextualized_content']},'similarity':x['_score']},top_hits))

        return results

    def search_hybrid(self, query: str, k: int = 20,semantic_weight: float = 0.8, bm25_weight: float = 0.2) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self._query_endpoint(json.dumps([query]).encode('utf-8'), 'application/x-text')['embedding']
            self.query_cache[query] = query_embedding

        query_json = {
            "size": k,
            "query": {
                "bool": {
                    "should": [{
                        "knn": {
                            "embedding": {
                                "vector": query_embedding[0],
                                "k": k,
                                "boost": semantic_weight
                            }
                        }},{
                        "multi_match": {
                            "query": query,
                            "fields": ["original_content", "contextualized_content"],
                            "boost":bm25_weight
                        }}
                    ]
                }
            }
        }

        response = open_search_client.search(
            body = query_json,
            index = index_name
        )

        top_hits = response['hits']['hits']

        results = list(map(lambda x: {'metadata':{'doc_id':x['_source']['doc_id'],'original_uuid':x['_source']['original_uuid'],'chunk_id':x['_source']['chunk_id'],'original_index':x['_source']['original_index'],'original_content':x['_source']['original_content'],'contextualized_content':x['_source']['contextualized_content']},'similarity':x['_score']},top_hits))

        return results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])


In [11]:
# Load the transformed dataset
with open('data/codebase_chunks.json', 'r') as f:
    transformed_dataset = json.load(f)

# Initialize the ContextualVectorDB
contextual_db = ContextualVectorDB("my_contextual_db")

# Load and process the data
contextual_db.load_data(transformed_dataset, parallel_threads=5)

Loading vector database from disk.


In [None]:
# bulk index documents into AOSS
docs = []
for i in range(len(contextual_db.embeddings)):
    docs.append({'index':{ "_index": index_name}})
    docs.append({**contextual_db.metadata[i],'embedding':contextual_db.embeddings[i]})
    
open_search_client.bulk(body=docs)
    

In [12]:
r5 = evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 5)
r10 = evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 10)
r20 = evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 20)

Pass@5: 85.66%
Total Score: 0.8565668202764978
Total queries: 248
Pass@10: 89.78%
Total Score: 0.8978494623655914
Total queries: 248
Pass@20: 93.15%
Total Score: 0.9314516129032258
Total queries: 248
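The introduction reports gains as a reduction in retrieval failure rate rather than as Pass@k. Assuming the failure rate is simply 100% minus Pass@k, the short sketch below converts the scores printed in this notebook into that form. The function name is illustrative.

```python
def failure_rate_reduction(baseline_pass: float, improved_pass: float) -> float:
    """Relative reduction in failure rate (in percent) given two Pass@k percentages."""
    baseline_fail = 100.0 - baseline_pass
    improved_fail = 100.0 - improved_pass
    return 100.0 * (baseline_fail - improved_fail) / baseline_fail


# Pass@k values copied from the basic and contextual runs in this notebook
for k, base, ctx in [(5, 76.15, 85.66), (10, 83.54, 89.78), (20, 88.71, 93.15)]:
    print(f"Pass@{k}: failure rate reduced by {failure_rate_reduction(base, ctx):.1f}%")
```

For this run the top-20 failure rate drops from about 11.3% to about 6.9%, a relative reduction of roughly 39%.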


## Contextual BM25

Contextual embeddings is an improvement on traditional semantic search RAG, but we can improve performance further. In this section we'll show you how you can use contextual embeddings and *contextual* BM25 together. While you can see performance gains by pairing these techniques together without the context, adding context to these methods reduces the top-20-chunk retrieval failure rate by 42%.

BM25 is a probabilistic ranking function that improves upon TF-IDF. It scores documents based on query term frequency, while accounting for document length and term saturation. BM25 is widely used in modern search engines for its effectiveness in ranking relevant documents.
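To show what term saturation and length normalization mean in practice, here is a minimal textbook-style BM25 scorer in plain Python. It is a sketch rather than what OpenSearch runs internally: the whitespace tokenizer, the function name, and the toy documents are assumptions, and `k1=1.2`, `b=0.75` are commonly used default values.

```python
import math
from collections import Counter
from typing import List


def bm25_scores(query: str, docs: List[str], k1: float = 1.2, b: float = 0.75) -> List[float]:
    """Score each (non-empty) document against the query with a textbook BM25 formula."""
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    avg_len = sum(len(tokens) for tokens in tokenized) / n_docs
    # Number of documents that contain each term
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))

    scores = []
    for tokens in tokenized:
        tf = Counter(tokens)
        score = 0.0
        for term in query.lower().split():
            if term not in tf:
                continue
            # Rare terms get a higher inverse document frequency
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            # Length normalization: longer-than-average documents are penalized
            norm = tf[term] + k1 * (1 - b + b * len(tokens) / avg_len)
            # Saturation: repeated occurrences add less and less to the score
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


docs = [
    "def parse_config(path): load yaml config",
    "retry the request with exponential backoff",
    "config config config loader for the request handler",
]
print(bm25_scores("config", docs))
```

The third document mentions `config` three times but scores less than three times the first document, which is the saturation effect described above.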

One difference between a typical BM25 search and what we'll do in this section is that we'll run each BM25 search over both the chunk content and the additional context that we generated for each chunk in the previous section. From there, we'll use a technique called reciprocal rank fusion to merge the results from our BM25 search with our semantic search results. This allows us to perform a hybrid search across both our BM25 corpus and vector DB to return the most relevant documents for a given query.

In the function below, you can optionally weight the semantic search and BM25 search results as you merge them with Reciprocal Rank Fusion. By default, we set these weights to 0.8 for the semantic search results and 0.2 for the BM25 results. We'd encourage you to experiment with different values here.
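Before the full implementation, the core of weighted rank fusion can be sketched on its own. Each ranked list contributes `weight / rank` for every item it contains, and items are re-sorted by the summed score. The function name and the toy chunk IDs are illustrative assumptions. The `smoothing` parameter is also an assumption added here: commonly cited formulations of reciprocal rank fusion add a constant (often 60) to the rank, while the default of 0 matches the plain `1 / rank` scoring used in this guide.

```python
from typing import Dict, Hashable, List, Sequence


def weighted_rank_fusion(
    semantic: Sequence[Hashable],
    bm25: Sequence[Hashable],
    semantic_weight: float = 0.8,
    bm25_weight: float = 0.2,
    smoothing: float = 0.0,
) -> List[Hashable]:
    """Merge two ranked lists by summing weight / (smoothing + rank) for each item."""
    scores: Dict[Hashable, float] = {}
    for weight, ranking in ((semantic_weight, semantic), (bm25_weight, bm25)):
        for rank, item in enumerate(ranking, start=1):
            scores[item] = scores.get(item, 0.0) + weight / (smoothing + rank)
    # Highest fused score first
    return sorted(scores, key=lambda item: scores[item], reverse=True)


# Toy rankings: "c" is third for semantic search but first for BM25
fused = weighted_rank_fusion(["a", "b", "c"], ["c", "a", "d"])
print(fused)  # ['a', 'c', 'b', 'd']
```

Here `c` overtakes `b` because its top BM25 rank adds 0.2 to a semantic score of about 0.27, while `b` only has its semantic score of 0.4.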

In [13]:
import os
import json
from typing import List, Dict, Any

class OpenSearchBM25:
    def __init__(self):
        self.index_name = index_name

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        # self.es_client.indices.refresh(index=self.index_name)  # Force refresh before each search
        search_body = {
            "query": {
                "multi_match": {
                    "query": query,
                    "fields": ["original_content", "contextualized_content"],
                }
            },
            "size": k,
        }
        response = open_search_client.search(index=self.index_name, body=search_body)
        return [
            {
                "doc_id": hit["_source"]["doc_id"],
                "original_index": hit["_source"]["original_index"],
                "original_content": hit["_source"]["original_content"],
                "contextualized_content": hit["_source"]["contextualized_content"],
                "score": hit["_score"],
            }
            for hit in response["hits"]["hits"]
        ]

def retrieve_advanced(query: str, db: ContextualVectorDB, db_bm25: OpenSearchBM25, k: int, semantic_weight: float = 0.8, bm25_weight: float = 0.2):
    num_chunks_to_recall = 150

    # Semantic search
    semantic_results = db.search(query, k=num_chunks_to_recall)
    ranked_chunk_ids = [(result['metadata']['doc_id'], result['metadata']['original_index']) for result in semantic_results]

    # BM25 search using OpenSearch
    bm25_results = db_bm25.search(query, k=num_chunks_to_recall)
    ranked_bm25_chunk_ids = [(result['doc_id'], result['original_index']) for result in bm25_results]

    # Combine results
    chunk_ids = list(set(ranked_chunk_ids + ranked_bm25_chunk_ids))
    chunk_id_to_score = {}

    # Initial scoring with weights
    for chunk_id in chunk_ids:
        score = 0
        if chunk_id in ranked_chunk_ids:
            index = ranked_chunk_ids.index(chunk_id)
            score += semantic_weight * (1 / (index + 1))  # Weighted 1/n scoring for semantic
        if chunk_id in ranked_bm25_chunk_ids:
            index = ranked_bm25_chunk_ids.index(chunk_id)
            score += bm25_weight * (1 / (index + 1))  # Weighted 1/n scoring for BM25
        chunk_id_to_score[chunk_id] = score

    # Sort chunk IDs by their scores in descending order
    sorted_chunk_ids = sorted(
        chunk_id_to_score.keys(), key=lambda x: (chunk_id_to_score[x], x[0], x[1]), reverse=True
    )

    # Assign new scores based on the sorted order
    for index, chunk_id in enumerate(sorted_chunk_ids):
        chunk_id_to_score[chunk_id] = 1 / (index + 1)

    # Prepare the final results
    final_results = []
    semantic_count = 0
    bm25_count = 0
    for chunk_id in sorted_chunk_ids[:k]:
        chunk_metadata = next(chunk for chunk in db.metadata if chunk['doc_id'] == chunk_id[0] and chunk['original_index'] == chunk_id[1])
        is_from_semantic = chunk_id in ranked_chunk_ids
        is_from_bm25 = chunk_id in ranked_bm25_chunk_ids
        final_results.append({
            'chunk': chunk_metadata,
            'score': chunk_id_to_score[chunk_id],
            'from_semantic': is_from_semantic,
            'from_bm25': is_from_bm25
        })
        
        if is_from_semantic and not is_from_bm25:
            semantic_count += 1
        elif is_from_bm25 and not is_from_semantic:
            bm25_count += 1
        else:  # it's in both
            semantic_count += 0.5
            bm25_count += 0.5

    return final_results, semantic_count, bm25_count

def evaluate_db_advanced(db: ContextualVectorDB, original_jsonl_path: str, k: int):
    original_data = load_jsonl(original_jsonl_path)
    db_bm25 = OpenSearchBM25()
    # Warm-up queries
    warm_up_queries = original_data[:10]
    for query_item in warm_up_queries:
        _ = retrieve_advanced(query_item['query'], db, db_bm25, k)
    
    total_score = 0
    total_semantic_count = 0
    total_bm25_count = 0
    total_results = 0
    
    for query_item in tqdm(original_data, desc="Evaluating retrieval"):
        query = query_item['query']
        golden_chunk_uuids = query_item['golden_chunk_uuids']
        
        golden_contents = []
        for doc_uuid, chunk_index in golden_chunk_uuids:
            golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)
            if golden_doc:
                golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)
                if golden_chunk:
                    golden_contents.append(golden_chunk['content'].strip())
        
        if not golden_contents:
            print(f"Warning: No golden contents found for query: {query}")
            continue
        
        retrieved_docs, semantic_count, bm25_count = retrieve_advanced(query, db, db_bm25, k)
        
        chunks_found = 0
        for golden_content in golden_contents:
            for doc in retrieved_docs[:k]:
                retrieved_content = doc['chunk']['original_content'].strip()
                if retrieved_content == golden_content:
                    chunks_found += 1
                    break
        
        query_score = chunks_found / len(golden_contents)
        total_score += query_score
        
        total_semantic_count += semantic_count
        total_bm25_count += bm25_count
        total_results += len(retrieved_docs)
    
    total_queries = len(original_data)
    average_score = total_score / total_queries
    pass_at_n = average_score * 100
    
    semantic_percentage = (total_semantic_count / total_results) * 100 if total_results > 0 else 0
    bm25_percentage = (total_bm25_count / total_results) * 100 if total_results > 0 else 0
    
    results = {
        "pass_at_n": pass_at_n,
        "average_score": average_score,
        "total_queries": total_queries
    }
    
    print(f"Pass@{k}: {pass_at_n:.2f}%")
    print(f"Average Score: {average_score:.2f}")
    print(f"Total queries: {total_queries}")
    print(f"Percentage of results from semantic search: {semantic_percentage:.2f}%")
    print(f"Percentage of results from BM25: {bm25_percentage:.2f}%")
    
    return results, {"semantic": semantic_percentage, "bm25": bm25_percentage}

def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    with open(file_path, 'r') as file:
        return [json.loads(line) for line in file]

# lexical and semantic search in 1 pass through AOSS query
def retrieve_advanced_hybrid(query: str, db: ContextualVectorDB, k: int, semantic_weight: float = 0.8, bm25_weight: float = 0.2):
    num_chunks_to_recall = 150

    # hybrid search
    results = db.search_hybrid(query, k=num_chunks_to_recall, semantic_weight=semantic_weight, bm25_weight=bm25_weight)
    ranked_results = [result['metadata']['original_content'] for result in results]

    return ranked_results

# lexical and semantic search in 1 pass through AOSS query
def evaluate_db_hybrid(db: ContextualVectorDB, original_jsonl_path: str, k: int):
    original_data = load_jsonl(original_jsonl_path)
    total_score = 0
    total_results = 0
    for query_item in tqdm(original_data, desc="Evaluating retrieval"):
        query = query_item['query']
        golden_chunk_uuids = query_item['golden_chunk_uuids']
        
        golden_contents = []
        for doc_uuid, chunk_index in golden_chunk_uuids:
            golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)
            if golden_doc:
                golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)
                if golden_chunk:
                    golden_contents.append(golden_chunk['content'].strip())
        
        if not golden_contents:
            print(f"Warning: No golden contents found for query: {query}")
            continue
        
        retrieved_docs = retrieve_advanced_hybrid(query, db, k)
        
        chunks_found = 0
        for golden_content in golden_contents:
            for doc in retrieved_docs[:k]:
                if doc.strip() == golden_content:
                    chunks_found += 1
                    break
        
        query_score = chunks_found / len(golden_contents)
        total_score += query_score
        
        total_results += len(retrieved_docs)
    
    total_queries = len(original_data)
    average_score = total_score / total_queries
    pass_at_n = average_score * 100
    
    results = {
        "pass_at_n": pass_at_n,
        "average_score": average_score,
        "total_queries": total_queries
    }
    
    print(f"Pass@{k}: {pass_at_n:.2f}%")
    print(f"Average Score: {average_score:.2f}")
    print(f"Total queries: {total_queries}")
    
    return results

In [14]:
results5 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 5)
results10 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 10)
results20 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 20)

Evaluating retrieval: 100%|██████████| 248/248 [04:34<00:00,  1.11s/it]


Pass@5: 86.97%
Average Score: 0.87
Total queries: 248
Percentage of results from semantic search: 58.43%
Percentage of results from BM25: 41.57%


Evaluating retrieval: 100%|██████████| 248/248 [04:16<00:00,  1.03s/it]


Pass@10: 89.73%
Average Score: 0.90
Total queries: 248
Percentage of results from semantic search: 62.80%
Percentage of results from BM25: 37.20%


Evaluating retrieval: 100%|██████████| 248/248 [04:48<00:00,  1.16s/it]

Pass@20: 93.25%
Average Score: 0.93
Total queries: 248
Percentage of results from semantic search: 65.94%
Percentage of results from BM25: 34.06%





Alternatively, you can combine lexical and semantic search directly through OpenSearch's query DSL, using boosting as a substitute for weighted reciprocal rank fusion. Notice that the results are much worse than before: Pass@10 drops from 89.73% to 71.59%.
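For orientation, a boosted single-pass query body might look roughly like the sketch below. This is not the `search_hybrid` implementation used in this guide. The `build_hybrid_query` helper, the `embedding` vector field name, and the placement of `boost` on the `knn` clause are assumptions to check against your index mapping and OpenSearch version.

```python
from typing import Any, Dict, List


def build_hybrid_query(
    query: str,
    query_embedding: List[float],
    k: int = 20,
    semantic_weight: float = 0.8,
    bm25_weight: float = 0.2,
) -> Dict[str, Any]:
    """Build a bool/should query that mixes lexical and vector clauses."""
    return {
        "size": k,
        "query": {
            "bool": {
                "should": [
                    {
                        # Lexical (BM25) clause over both text fields
                        "multi_match": {
                            "query": query,
                            "fields": ["original_content", "contextualized_content"],
                            "boost": bm25_weight,
                        }
                    },
                    {
                        # Vector clause; "embedding" is a hypothetical field name
                        "knn": {
                            "embedding": {
                                "vector": query_embedding,
                                "k": k,
                                "boost": semantic_weight,
                            }
                        }
                    },
                ]
            }
        },
    }


body = build_hybrid_query("parse config file", [0.1, 0.2, 0.3], k=5)
```

One plausible reason this approach underperforms is that boosting multiplies raw scores, and BM25 scores and vector similarity scores are on different scales, so one signal can dominate the sum. Rank-based fusion sidesteps this by using only each result's position.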

In [15]:
results5 = evaluate_db_hybrid(contextual_db, 'data/evaluation_set.jsonl', 5)
results10 = evaluate_db_hybrid(contextual_db, 'data/evaluation_set.jsonl', 10)
results20 = evaluate_db_hybrid(contextual_db, 'data/evaluation_set.jsonl', 20)

Evaluating retrieval: 100%|██████████| 248/248 [02:27<00:00,  1.68it/s]


Pass@5: 62.38%
Average Score: 0.62
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [02:34<00:00,  1.61it/s]


Pass@10: 71.59%
Average Score: 0.72
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [02:20<00:00,  1.77it/s]

Pass@20: 78.57%
Average Score: 0.79
Total queries: 248





## Adding a Re-Ranking Step

If you want to improve performance further, we recommend adding a re-ranking step. When using a re-ranker, you can retrieve more documents initially from your vector store, then use your re-ranker to select a subset of these documents. One common technique is to use re-ranking as a way to implement high-precision hybrid search. You can use a combination of semantic search and keyword-based search in your initial retrieval step (as we have done earlier in this guide), then use a re-ranking step to choose only the k most relevant docs from the combined list of documents returned by your semantic search and keyword search systems.

Below, we'll demonstrate only the re-ranking step (skipping the hybrid search technique for now). We retrieve 10x as many documents as the final k we want to return, then use a re-ranking model from Cohere to select the k most relevant results from that list. Adding the re-ranking step delivers a modest additional gain in performance. In the run below, Pass@10 reaches 93.08%, compared with 89.73% for the hybrid search above.
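The over-retrieve-then-rerank pattern can be sketched independently of any particular re-ranking model. In the example below, `retrieve_then_rerank` and the `token_overlap` scorer are hypothetical stand-ins: a real re-ranker (such as the Cohere model used in this guide) would replace the scoring function, and the candidate list would come from your vector store.

```python
from typing import Callable, List, Tuple


def retrieve_then_rerank(
    query: str,
    candidates: List[str],
    score_fn: Callable[[str, str], float],
    k: int,
) -> List[Tuple[str, float]]:
    """Keep a pool of up to 10 * k candidates, rescore them, and return the top k."""
    pool = candidates[: k * 10]
    scored = [(doc, score_fn(query, doc)) for doc in pool]
    # Highest relevance score first
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


def token_overlap(query: str, doc: str) -> float:
    """Toy relevance score: fraction of query tokens that appear in the document."""
    q_tokens = set(query.lower().split())
    d_tokens = set(doc.lower().split())
    return len(q_tokens & d_tokens) / len(q_tokens) if q_tokens else 0.0


candidates = [
    "how to render html",
    "parse the config file",
    "config file loader and parser",
    "unrelated text",
]
top = retrieve_then_rerank("parse config file", candidates, token_overlap, k=2)
```

With this toy scorer, `top` contains "parse the config file" first (all three query tokens match) followed by "config file loader and parser" (two of three match), regardless of where they appeared in the initial candidate list.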

As above, launch the Cohere Rerank model through [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html#jumpstart-open-use-studio). Because the model is proprietary, this requires an AWS Marketplace subscription.

In [11]:
reranker_enpoint_name = '<your sagemaker endpoint name for reranker>' # sagemaker endpoint

In [12]:
from cohere_aws import Client
from typing import List, Dict, Any, Callable
import json
import time  # needed for the time.sleep() call in retrieve_rerank
from tqdm import tqdm

def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    with open(file_path, 'r') as file:
        return [json.loads(line) for line in file]

def chunk_to_content(chunk: Dict[str, Any]) -> str:
    original_content = chunk['metadata']['original_content']
    contextualized_content = chunk['metadata']['contextualized_content']
    return f"{original_content}\n\nContext: {contextualized_content}" 

def retrieve_rerank(query: str, db, k: int) -> List[Dict[str, Any]]:
    co = Client(endpoint_name=reranker_enpoint_name)
    
    # Retrieve more results than we normally would
    semantic_results = db.search(query, k=k*10)
    
    # Extract documents for reranking, using the contextualized content
    documents = [chunk_to_content(res) for res in semantic_results]

    response = co.rerank(
        query=query,
        documents=documents,
        top_n=k
    )
    time.sleep(0.1)
    
    final_results = []
    for r in response.results:
        original_result = semantic_results[r.index]
        final_results.append({
            "chunk": original_result['metadata'],
            "score": r.relevance_score
        })
    
    return final_results

def evaluate_retrieval_rerank(queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20) -> Dict[str, float]:
    total_score = 0
    total_queries = len(queries)
    
    for query_item in tqdm(queries, desc="Evaluating retrieval"):
        query = query_item['query']
        golden_chunk_uuids = query_item['golden_chunk_uuids']
        
        golden_contents = []
        for doc_uuid, chunk_index in golden_chunk_uuids:
            golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)
            if golden_doc:
                golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)
                if golden_chunk:
                    golden_contents.append(golden_chunk['content'].strip())
        
        if not golden_contents:
            print(f"Warning: No golden contents found for query: {query}")
            continue
        
        retrieved_docs = retrieval_function(query, db, k)
        
        chunks_found = 0
        for golden_content in golden_contents:
            for doc in retrieved_docs[:k]:
                retrieved_content = doc['chunk']['original_content'].strip()
                if retrieved_content == golden_content:
                    chunks_found += 1
                    break
        
        query_score = chunks_found / len(golden_contents)
        total_score += query_score
    
    average_score = total_score / total_queries
    pass_at_n = average_score * 100
    return {
        "pass_at_n": pass_at_n,
        "average_score": average_score,
        "total_queries": total_queries
    }

def evaluate_db_advanced(db, original_jsonl_path, k):
    original_data = load_jsonl(original_jsonl_path)
    
    def retrieval_function(query, db, k):
        return retrieve_rerank(query, db, k)
    
    results = evaluate_retrieval_rerank(original_data, retrieval_function, db, k)
    print(f"Pass@{k}: {results['pass_at_n']:.2f}%")
    print(f"Average Score: {results['average_score']}")
    print(f"Total queries: {results['total_queries']}")
    return results

In [13]:
results5 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 5)
results10 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 10)
results20 = evaluate_db_advanced(contextual_db, 'data/evaluation_set.jsonl', 20)

Evaluating retrieval: 100%|██████████| 248/248 [08:22<00:00,  2.03s/it]


Pass@5: 90.24%
Average Score: 0.9023617511520737
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [07:17<00:00,  1.76s/it]


Pass@10: 93.08%
Average Score: 0.9307795698924731
Total queries: 248


Evaluating retrieval: 100%|██████████| 248/248 [09:03<00:00,  2.19s/it]

Pass@20: 94.66%
Average Score: 0.9465725806451613
Total queries: 248



