# Advanced Retrieval: Contextual Embeddings

Claude excels at a wide range of tasks, but it may struggle with queries specific to your unique business context. This is where Retrieval Augmented Generation (RAG) becomes invaluable. RAG enables Claude to leverage your internal knowledge bases, codebases, or any other corpus of documents when providing a response, which significantly improves its performance on domain-specific tasks. Enterprises are increasingly building RAG applications to improve workflows in customer support, Q&A over internal company documents, financial and legal analysis, code generation, and much more.

In a [separate guide](https://github.com/anthropics/anthropic-cookbook/blob/main/skills/retrieval_augmented_generation/guide.ipynb), we walk through setting up a basic retrieval pipeline, evaluating its performance, and then systematically improving it according to best practices. In this guide, we present a new approach to RAG: Contextual Retrieval. This method enriches each chunk with context about its source document before the chunk is embedded, allowing for more accurate semantic similarity searches and thus better overall performance.

In this guide, we'll demonstrate how to build and optimize a RAG system using a collection of codebases as our knowledge base. We'll walk through:

1) Setting up a basic retrieval pipeline to establish a baseline for performance.

2) Contextual Retrieval: what it is, why it works, and how prompt caching makes it practical for production use cases.

3) Implementing contextual retrieval and demonstrating performance improvements.

### Evaluation Metrics & Dataset:

In this guide, we use a pre-chunked dataset of 9 codebases, all of which have been chunked with a basic character-splitting mechanism. Our evaluation dataset contains 248 queries, each paired with one or more 'golden chunks' that a correct retrieval must return. We measure performance with a metric called Pass@k, which checks whether the golden chunks for each query appear among the first `k` chunks retrieved. In this case, contextual retrieval improved Pass@10 from 83.47% to 91.13% without changing any other aspect of our system.
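A minimal sketch of the Pass@k check, following the rule used by `evaluate_retrieval` later in this notebook: a query passes only if all of its golden chunks appear in the top `k` results. The helpers `pass_at_k` and `pass_at_k_rate` are illustrative names, not part of any library, and the chunk IDs are toy stand-ins for chunk contents.

```python
from typing import Iterable, List, Sequence


def pass_at_k(retrieved: Sequence[str], golden: Iterable[str], k: int) -> bool:
    """Return True if every golden chunk appears in the top-k retrieved chunks."""
    top_k = set(retrieved[:k])
    return all(chunk in top_k for chunk in golden)


def pass_at_k_rate(results: List[bool]) -> float:
    """Percentage of queries that passed."""
    return 100.0 * sum(results) / len(results)


# Toy ranking for a single query
retrieved = ["c7", "c2", "c9", "c4"]
print(pass_at_k(retrieved, ["c2"], k=2))        # True
print(pass_at_k(retrieved, ["c2", "c4"], k=3))  # False: c4 is ranked 4th

# The headline numbers: 207 and 226 of 248 queries passed at k=10
print(f"{pass_at_k_rate([True] * 207 + [False] * 41):.2f}%")  # 83.47%
print(f"{pass_at_k_rate([True] * 226 + [False] * 22):.2f}%")  # 91.13%
```

Requiring every golden chunk (rather than any one of them) makes the metric stricter for queries whose answer spans several chunks.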

You can find the code files and their chunks in `data/codebase_chunks.json` and the evaluation dataset in `data/evaluation_set.jsonl`.
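To make the two file formats concrete, here is a small sketch that works on records shaped like them. The field names are the keys this notebook's code reads (`doc_id`, `original_uuid`, `chunks`, `golden_chunk_uuids`, `golden_documents`, and so on); the toy values and the helper names `summarize_chunks` and `golden_contents` are made up for illustration and are not taken from the actual files.

```python
from typing import Any, Dict, List


def summarize_chunks(dataset: List[Dict[str, Any]]) -> Dict[str, int]:
    """Count documents and chunks in a codebase_chunks.json-style list."""
    return {
        "documents": len(dataset),
        "chunks": sum(len(doc["chunks"]) for doc in dataset),
    }


def golden_contents(query_item: Dict[str, Any]) -> List[str]:
    """Resolve each [doc_uuid, chunk_index] pair to the golden chunk's text."""
    docs = {doc["uuid"]: doc for doc in query_item["golden_documents"]}
    contents = []
    for doc_uuid, chunk_index in query_item["golden_chunk_uuids"]:
        by_index = {c["index"]: c["content"] for c in docs[doc_uuid]["chunks"]}
        contents.append(by_index[chunk_index].strip())
    return contents


# Toy records (made-up values) shaped like one entry of each file
toy_dataset = [
    {
        "doc_id": "doc_a",
        "original_uuid": "uuid-a",
        "content": "fn a() {}\nfn b() {}",
        "chunks": [
            {"chunk_id": "doc_a_0", "original_index": 0, "content": "fn a() {}"},
            {"chunk_id": "doc_a_1", "original_index": 1, "content": "fn b() {}"},
        ],
    }
]
toy_query = {
    "query": "Where is b defined?",
    "golden_chunk_uuids": [["uuid-a", 1]],
    "golden_documents": [
        {
            "uuid": "uuid-a",
            "content": "fn a() {}\nfn b() {}",
            "chunks": [
                {"index": 0, "content": "fn a() {}"},
                {"index": 1, "content": "fn b() {}"},
            ],
        }
    ],
}

print(summarize_chunks(toy_dataset))  # {'documents': 1, 'chunks': 2}
print(golden_contents(toy_query))     # ['fn b() {}']
```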

#### Note:

[Prompt caching](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching) is helpful in managing costs when using this retrieval method. This feature is currently available on Anthropic's 1P API, and is coming soon to our 3P partner environments like AWS Bedrock and GCP Vertex. We know that many of our customers leverage AWS Knowledge Bases and GCP Vertex AI APIs when building RAG solutions, and we're confident that this method can be used on either platform with a bit of customization. Consider reaching out to Anthropic or your AWS/GCP account team for guidance on this!

## Table of Contents

1) Setup

2) Basic RAG

3) Contextual Retrieval

## Setup

We'll need a few libraries, including:

1) `anthropic` - to interact with Claude

2) `voyageai` - to generate high quality embeddings

3) `pandas`, `numpy`, `matplotlib`, and `scikit-learn` for data manipulation and visualization


You'll also need API keys from [Anthropic](https://www.anthropic.com/) and [Voyage AI](https://www.voyageai.com/).

In [None]:
!pip install anthropic
!pip install voyageai
!pip install pandas
!pip install numpy
!pip install tqdm

In [1]:
import os

os.environ['VOYAGE_API_KEY'] = "YOUR KEY HERE"
os.environ['ANTHROPIC_API_KEY'] = "YOUR KEY HERE"

In [2]:
import anthropic

client = anthropic.Anthropic(
    # This is the default and can be omitted
    api_key=os.getenv("ANTHROPIC_API_KEY"),
)

### Initialize a Vector DB Class

In this example, we're using an in-memory vector DB, but for a production application, you may want to use a hosted solution. 
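`VectorDB.search` below ranks chunks with a raw dot product (`np.dot`). A dot product equals cosine similarity only for unit-length vectors. Voyage AI's documentation states that its embeddings are normalized to length 1, so the two coincide here, but normalizing explicitly is a cheap safeguard if you swap in a different embedding model. This NumPy sketch (the helper `cosine_top_k` is hypothetical) shows how the rankings can diverge on unnormalized toy vectors:

```python
import numpy as np


def cosine_top_k(embeddings: np.ndarray, query: np.ndarray, k: int) -> np.ndarray:
    """Return indices of the k rows of `embeddings` most cosine-similar to `query`."""
    # Normalize the rows and the query so the dot product equals cosine similarity
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    q = query / np.linalg.norm(query)
    similarities = emb @ q
    # argsort is ascending, so reverse it to put the highest scores first
    return np.argsort(similarities)[::-1][:k]


embeddings = np.array([[1.0, 0.0], [0.6, 0.8], [10.0, 1.0]])
query = np.array([0.0, 1.0])

print(np.argsort(embeddings @ query)[::-1][:2])  # raw dot product: [2 1]
print(cosine_top_k(embeddings, query, k=2))      # cosine: [1 2]
```

The third vector wins under the raw dot product only because of its large magnitude; after normalization, the second vector, which points almost the same way as the query, ranks first.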



In [4]:
import os
import pickle
import json
import numpy as np
import voyageai
from typing import List, Dict, Any
from tqdm import tqdm

class VectorDB:
    def __init__(self, name: str, api_key = None):
        if api_key is None:
            api_key = os.getenv("VOYAGE_API_KEY")
        self.client = voyageai.Client(api_key=api_key)
        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/vector_db.pkl"

    def load_data(self, dataset: List[Dict[str, Any]]):
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)
        
        with tqdm(total=total_chunks, desc="Processing chunks") as pbar:
            for doc in dataset:
                for chunk in doc['chunks']:
                    texts_to_embed.append(chunk['content'])
                    metadata.append({
                        'doc_id': doc['doc_id'],
                        'original_uuid': doc['original_uuid'],
                        'chunk_id': chunk['chunk_id'],
                        'original_index': chunk['original_index'],
                        'content': chunk['content']
                    })
                    pbar.update(1)

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()
        
        print(f"Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}")

    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 128
        with tqdm(total=len(texts), desc="Embedding chunks") as pbar:
            result = []
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                batch_result = self.client.embed(batch, model="voyage-2").embeddings
                result.extend(batch_result)
                pbar.update(len(batch))
        
        self.embeddings = result
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self.client.embed([query], model="voyage-2").embeddings[0]
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1][:k]
        
        top_results = []
        for idx in top_indices:
            result = {
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            top_results.append(result)
        
        return top_results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])

    def validate_embedded_chunks(self):
        unique_contents = set()
        for meta in self.metadata:
            unique_contents.add(meta['content'])
    
        print(f"Validation results:")
        print(f"Total embedded chunks: {len(self.metadata)}")
        print(f"Unique embedded contents: {len(unique_contents)}")
    
        if len(self.metadata) != len(unique_contents):
            print("Warning: There may be duplicate chunks in the embedded data.")
        else:
            print("All embedded chunks are unique.")

In [5]:
# Load your transformed dataset
with open('data/codebase_chunks.json', 'r') as f:
    transformed_dataset = json.load(f)

# Initialize the VectorDB
base_db = VectorDB("base_db")

# Load and process the data
base_db.load_data(transformed_dataset)

Processing chunks: 100%|██████████| 737/737 [00:00<00:00, 1102032.82it/s]
Embedding chunks: 100%|██████████| 737/737 [00:13<00:00, 55.43it/s]

Vector database loaded and saved. Total chunks processed: 737





## Basic RAG

To get started, we'll set up a basic RAG pipeline using a bare-bones approach, sometimes called 'Naive RAG' in the industry. A basic RAG pipeline includes the following 3 steps:

1) Chunk documents (here, each code file has already been split into chunks by a basic character splitter)

2) Embed each chunk

3) Use cosine similarity to retrieve the chunks most relevant to each query
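The three steps above can be sketched end to end without any API calls. This is a toy illustration, not the notebook's pipeline: `split_by_characters`, `toy_embed`, and `retrieve` are hypothetical helpers, the chunk size and overlap are arbitrary, and a normalized word-count vector stands in for a real embedding model such as Voyage.

```python
import re
from collections import Counter
from typing import List

import numpy as np


def split_by_characters(text: str, chunk_size: int, overlap: int = 0) -> List[str]:
    """Step 1: split text into fixed-size character chunks, optionally overlapping."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    step = chunk_size - overlap
    return [text[i : i + chunk_size] for i in range(0, len(text), step)]


def toy_embed(texts: List[str], vocab: List[str]) -> np.ndarray:
    """Step 2 stand-in: L2-normalized word counts over a fixed vocabulary.

    A real pipeline would call an embedding model here instead.
    """
    rows = []
    for text in texts:
        counts = Counter(re.findall(r"\w+", text.lower()))
        rows.append([counts[word] for word in vocab])
    vectors = np.array(rows, dtype=float)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Leave all-zero rows as zeros instead of dividing by zero
    return vectors / np.where(norms == 0, 1.0, norms)


def retrieve(query: str, chunks: List[str], vocab: List[str], k: int) -> List[str]:
    """Step 3: return the k chunks with the highest cosine similarity to the query."""
    chunk_vectors = toy_embed(chunks, vocab)
    query_vector = toy_embed([query], vocab)[0]
    # Rows are unit length (or all zeros), so the dot product is the cosine similarity
    scores = chunk_vectors @ query_vector
    return [chunks[i] for i in np.argsort(scores)[::-1][:k]]


print(split_by_characters("abcdefghij", chunk_size=4, overlap=1))
# ['abcd', 'defg', 'ghij', 'j']

chunks = [
    "def read_config(path): open config file",
    "def open_socket(port): bind network socket",
    "def write_log(msg): append to log file",
]
vocab = ["config", "socket", "log", "file", "network"]
print(retrieve("which function opens a network socket", chunks, vocab, k=1))
# ['def open_socket(port): bind network socket']
```

Note that plain character splitting can cut through the middle of a word or function, which is one reason a chunk on its own may lack the context needed to match a query.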

In [14]:
import json
from typing import List, Dict, Any, Callable, Union
from tqdm import tqdm

def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Load JSONL file and return a list of dictionaries."""
    with open(file_path, 'r') as file:
        return [json.loads(line) for line in file]


def evaluate_retrieval(queries: List[Dict[str, Any]], retrieval_function: Callable, db, k: int = 20) -> Dict[str, float]:
    correct_retrievals = 0
    total_queries = len(queries)
    
    for query_item in tqdm(queries, desc="Evaluating retrieval"):
        query = query_item['query']
        golden_chunk_uuids = query_item['golden_chunk_uuids']
        
        # Find all golden chunk contents
        golden_contents = []
        for doc_uuid, chunk_index in golden_chunk_uuids:
            golden_doc = next((doc for doc in query_item['golden_documents'] if doc['uuid'] == doc_uuid), None)
            if not golden_doc:
                print(f"Warning: Golden document not found for UUID {doc_uuid}")
                continue
            
            golden_chunk = next((chunk for chunk in golden_doc['chunks'] if chunk['index'] == chunk_index), None)
            if not golden_chunk:
                print(f"Warning: Golden chunk not found for index {chunk_index} in document {doc_uuid}")
                continue
            
            golden_contents.append(golden_chunk['content'].strip())
        
        if not golden_contents:
            print(f"Warning: No golden contents found for query: {query}")
            continue
        
        retrieved_docs = retrieval_function(query, db, k=k)
        
        # Check if all golden chunks are in the top k retrieved documents
        all_chunks_found = True
        for golden_content in golden_contents:
            chunk_found = False
            for doc in retrieved_docs[:k]:
                retrieved_content = doc['metadata'].get('original_content', doc['metadata'].get('content', '')).strip()
                if retrieved_content == golden_content:
                    chunk_found = True
                    break
            if not chunk_found:
                all_chunks_found = False
                break
        
        if all_chunks_found:
            correct_retrievals += 1
    
    pass_at_n = (correct_retrievals / total_queries) * 100
    return {
        "pass_at_n": pass_at_n,
        "correct_retrievals": correct_retrievals,
        "total_queries": total_queries
    }

def retrieve_base(query: str, db, k: int = 20) -> List[Dict[str, Any]]:
    """
    Retrieve relevant documents using either VectorDB or ContextualVectorDB.
    
    :param query: The query string
    :param db: The VectorDB or ContextualVectorDB instance
    :param k: Number of top results to retrieve
    :return: List of retrieved documents
    """
    return db.search(query, k=k)

def evaluate_db(db, original_jsonl_path: str, k: int) -> Dict[str, float]:
    # Load the original JSONL data for queries and ground truth
    original_data = load_jsonl(original_jsonl_path)
    
    # Evaluate retrieval
    results = evaluate_retrieval(original_data, retrieve_base, db, k)
    print(f"Pass@{k}: {results['pass_at_n']:.2f}%")
    print(f"Correct retrievals: {results['correct_retrievals']}")
    print(f"Total queries: {results['total_queries']}")
    # Return the metrics so callers can compare runs programmatically
    return results
In [15]:
evaluate_db(base_db, 'data/evaluation_set.jsonl', 10)

Evaluating retrieval: 100%|██████████| 248/248 [00:06<00:00, 36.82it/s]

Pass@10: 83.47%
Correct retrievals: 207
Total queries: 248





## Contextual Retrieval

Contextual retrieval is about creating embeddings that carry more context than any individual chunk would on its own. With Basic RAG, you embed each text chunk directly; at query time, you embed the input query and retrieve the chunks whose embeddings are most similar to it.

With contextual retrieval, we change what gets embedded by adding context to each text chunk first. Specifically, we use Claude to write a short block of text that situates each chunk within the broader document it comes from. In our codebases dataset, each chunk lives within a full code file, so we pass both the chunk and the full file to Claude and ask it for this context. We then combine the context and the raw text chunk into a single text block before creating each embedding.
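A minimal sketch of that combination step, assuming the contexts have already been generated (for instance by `situate_context` later in this notebook). `build_contextual_records` and the `text_to_embed` key are hypothetical; `original_content` and `contextualized_content` mirror the metadata keys the notebook stores, and the sample chunk and context are made up.

```python
from typing import Dict, List


def build_contextual_records(chunks: List[str], contexts: List[str]) -> List[Dict[str, str]]:
    """Combine each raw chunk with its generated context for embedding.

    The chunk text comes first, then a blank line, then the context, matching the
    order used by ContextualVectorDB in this notebook. The raw chunk is also kept
    on its own so evaluation can compare against the original text.
    """
    if len(chunks) != len(contexts):
        raise ValueError("expected exactly one context per chunk")
    return [
        {
            "text_to_embed": f"{chunk}\n\n{context}",  # hypothetical key name
            "original_content": chunk,
            "contextualized_content": context,
        }
        for chunk, context in zip(chunks, contexts)
    ]


records = build_contextual_records(
    ["def add(a, b):\n    return a + b"],
    ["Defines the add helper in a small math utilities module."],
)
print(records[0]["text_to_embed"])
```

Only `text_to_embed` goes to the embedding model; keeping the untouched chunk alongside it means retrieval results can still be checked against the golden chunks exactly.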

### Is this Efficient?

This work happens at ingestion time. It's a cost you pay once when you store each document (and occasionally again if your knowledge base changes over time). By contrast, approaches like hypothetical document embeddings (HyDE) improve the representation of the query before executing a search. These techniques have been shown to be somewhat effective, but they add significant latency at query time.

Prompt caching is the other thing that makes Contextual Retrieval practical. Every chunk from a document is sent along with that full document, so without caching you would pay the full input price for the same document once per chunk. With Anthropic's prompt caching feature, input tokens read from the cache cost 90% less than regular input tokens; the first request for each document writes the cache, and cache writes are billed at a premium over the base input price. Because all of a document's chunks are processed back to back, most requests hit the cache and the savings add up quickly. When you load data into the ContextualVectorDB below, the logs show just how big this impact is.
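The percentage that `ContextualVectorDB.load_data` logs below (85.92% in this run) is the share of input tokens read from cache, not a cost figure. Anthropic's usage data reports uncached input tokens, cache writes, and cache reads separately, so a rough cost estimate has to weight each one. This sketch assumes Anthropic's published multipliers at the time of writing (cache writes at 1.25x and cache reads at 0.1x the base input price; check current pricing) and ignores output tokens, which caching does not affect. `relative_input_cost` is a hypothetical helper; the token counts are the ones logged in this notebook.

```python
def relative_input_cost(
    uncached: int,
    cache_writes: int,
    cache_reads: int,
    write_multiplier: float = 1.25,
    read_multiplier: float = 0.10,
) -> float:
    """Input cost with caching, as a fraction of billing every token at the base rate.

    The default multipliers are assumptions based on published prompt caching pricing.
    """
    total_tokens = uncached + cache_writes + cache_reads
    billed = uncached + write_multiplier * cache_writes + read_multiplier * cache_reads
    return billed / total_tokens


# Token totals logged when building the ContextualVectorDB in this notebook
ratio = relative_input_cost(uncached=500_383, cache_writes=112_961, cache_reads=3_053_534)
print(f"Input cost vs. no caching: {ratio:.1%}")  # 25.8%
print(f"Input cost savings: {1 - ratio:.1%}")     # 74.2%
```

Under these assumptions, the input bill for contextualizing all 737 chunks comes to roughly a quarter of what it would be without caching.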


In [9]:
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""

CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>

Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
Answer only with the succinct context and nothing else.
"""

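def situate_context(doc: str, chunk: str) -> anthropic.types.Message:
    # Prompt caching is generally available, so the standard Messages API accepts
    # cache_control blocks directly (older SDK versions required
    # client.beta.prompt_caching plus the "prompt-caching-2024-07-31" beta header).
    response = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1024,
        temperature=0.0,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        # The full document is identical for every chunk it contains, so
                        # mark it for caching; later calls for the same document read it from cache.
                        "type": "text",
                        "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                        "cache_control": {"type": "ephemeral"}
                    },
                    {
                        # The chunk changes on every call, so it comes after the cached prefix.
                        "type": "text",
                        "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                    }
                ]
            }
        ],
    )
    return response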

jsonl_data = load_jsonl('data/evaluation_set.jsonl')
# Example usage
doc_content = jsonl_data[0]['golden_documents'][0]['content']
chunk_content = jsonl_data[0]['golden_chunks'][0]['content']

response = situate_context(doc_content, chunk_content)
print(f"Situated context: {response.content[0].text}")

# Print cache performance metrics
print(f"Input tokens: {response.usage.input_tokens}")
print(f"Output tokens: {response.usage.output_tokens}")
print(f"Cache creation input tokens: {response.usage.cache_creation_input_tokens}")
print(f"Cache read input tokens: {response.usage.cache_read_input_tokens}")

Situated context: This chunk describes the `DiffExecutor` struct, which is an executor for differential fuzzing. It wraps two executors that are run sequentially with the same input, and also runs the secondary executor in the `run_target` method.
Input tokens: 366
Output tokens: 55
Cache creation input tokens: 3046
Cache read input tokens: 0


In [10]:
import os
import pickle
import json
import numpy as np
import voyageai
from typing import List, Dict, Any
from tqdm import tqdm
import anthropic

class ContextualVectorDB:
    def __init__(self, name: str, voyage_api_key=None, anthropic_api_key=None):
        if voyage_api_key is None:
            voyage_api_key = os.getenv("VOYAGE_API_KEY")
        if anthropic_api_key is None:
            anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
        
        self.voyage_client = voyageai.Client(api_key=voyage_api_key)
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_api_key)
        self.name = name
        self.embeddings = []
        self.metadata = []
        self.query_cache = {}
        self.db_path = f"./data/{name}/contextual_vector_db.pkl"

    def load_data(self, dataset: List[Dict[str, Any]]):
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts_to_embed = []
        metadata = []
        total_chunks = sum(len(doc['chunks']) for doc in dataset)
        
        total_input_tokens = 0
        total_output_tokens = 0
        total_cache_read_tokens = 0
        total_cache_creation_tokens = 0
        
        with tqdm(total=total_chunks, desc="Processing chunks") as pbar:
            for doc in dataset:
                doc_content = doc['content']
                for chunk in doc['chunks']:
                    contextualized_text, usage = self.situate_context(doc_content, chunk['content'])
                    texts_to_embed.append(f"{chunk['content']}\n\n{contextualized_text}")
                    metadata.append({
                        'doc_id': doc['doc_id'],
                        'original_uuid': doc['original_uuid'],
                        'chunk_id': chunk['chunk_id'],
                        'original_index': chunk['original_index'],
                        'original_content': chunk['content'],
                        'contextualized_content': contextualized_text
                    })
                    total_input_tokens += usage.input_tokens
                    total_output_tokens += usage.output_tokens
                    # The cache fields are optional in the SDK's usage type, so treat a missing value as 0
                    total_cache_read_tokens += usage.cache_read_input_tokens or 0
                    total_cache_creation_tokens += usage.cache_creation_input_tokens or 0
                    pbar.update(1)

        self._embed_and_store(texts_to_embed, metadata)
        self.save_db()

        print(f"Contextual Vector database loaded and saved. Total chunks processed: {len(texts_to_embed)}")
        print(f"Total input tokens without caching: {total_input_tokens}")
        print(f"Total output tokens: {total_output_tokens}")
        print(f"Total cache creation tokens: {total_cache_creation_tokens}")
        print(f"Total cache read tokens: {total_cache_read_tokens}")
        
        # Share of (non-cache-write) input tokens that were served from cache.
        # This is a token ratio, not a cost ratio: cache writes are billed above the base rate.
        savings_percentage = (total_cache_read_tokens / (total_cache_read_tokens + total_input_tokens)) * 100
        print(f"Savings from prompt caching: {savings_percentage:.2f}%")

    def situate_context(self, doc: str, chunk: str) -> tuple[str, Any]:
        DOCUMENT_CONTEXT_PROMPT = """
        <document>
        {doc_content}
        </document>
        """

        CHUNK_CONTEXT_PROMPT = """
        Here is the chunk we want to situate within the whole document
        <chunk>
        {chunk_content}
        </chunk>

        Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk.
        Answer only with the succinct context and nothing else.
        """

        # Prompt caching is generally available, so the standard Messages API accepts
        # cache_control directly; the document block is cached and reused across its chunks.
        response = self.anthropic_client.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=1024,
            temperature=0.0,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                            "cache_control": {"type": "ephemeral"}
                        },
                        {
                            "type": "text",
                            "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                        }
                    ]
                }
            ],
        )
        return response.content[0].text, response.usage

    def _embed_and_store(self, texts: List[str], data: List[Dict[str, Any]]):
        batch_size = 128
        result = [
            self.voyage_client.embed(
                texts[i : i + batch_size],
                model="voyage-2"
            ).embeddings
            for i in range(0, len(texts), batch_size)
        ]
        self.embeddings = [embedding for batch in result for embedding in batch]
        self.metadata = data

    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self.voyage_client.embed([query], model="voyage-2").embeddings[0]
            self.query_cache[query] = query_embedding

        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1][:k]
        
        top_results = []
        for idx in top_indices:
            result = {
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            top_results.append(result)
        return top_results

    def save_db(self):
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])

In [11]:
# Load the transformed dataset
with open('data/codebase_chunks.json', 'r') as f:
    transformed_dataset = json.load(f)

# Initialize the ContextualVectorDB
contextual_db = ContextualVectorDB("my_contextual_db")

# Load and process the data
contextual_db.load_data(transformed_dataset)

Processing chunks: 100%|██████████| 737/737 [13:49<00:00,  1.13s/it]


Contextual Vector database loaded and saved. Total chunks processed: 737
Total input tokens without caching: 500383
Total output tokens: 40273
Total cache creation tokens: 112961
Total cache read tokens: 3053534
Savings from prompt caching: 85.92%


In [12]:
evaluate_db(contextual_db, 'data/evaluation_set.jsonl', 10)

Evaluating retrieval: 100%|██████████| 248/248 [01:10<00:00,  3.52it/s]

Pass@10: 91.13%
Correct retrievals: 226
Total queries: 248





{'pass_at_n': 91.12903225806451,
 'correct_retrievals': 226,
 'total_queries': 248}
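To put the two runs side by side, the arithmetic below uses the counts printed above (207 and 226 correct out of 248 queries). Beyond the 7.66-point gain in Pass@10, the number of failed queries drops from 41 to 22, roughly 46% fewer failures. `compare_pass_rates` is just an illustrative helper.

```python
from typing import Dict


def compare_pass_rates(total: int, baseline_correct: int, improved_correct: int) -> Dict[str, float]:
    """Summarize the change between two Pass@k runs over the same query set."""
    baseline_failures = total - baseline_correct
    improved_failures = total - improved_correct
    return {
        "baseline_pass_pct": 100 * baseline_correct / total,
        "improved_pass_pct": 100 * improved_correct / total,
        "gain_in_points": 100 * (improved_correct - baseline_correct) / total,
        # Net relative drop in the number of failed queries
        "failure_reduction_pct": (
            100 * (baseline_failures - improved_failures) / baseline_failures
            if baseline_failures
            else 0.0
        ),
    }


summary = compare_pass_rates(total=248, baseline_correct=207, improved_correct=226)
for name, value in summary.items():
    print(f"{name}: {value:.2f}")
```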