# Implementing Prompt Caching with LangChain and Amazon Bedrock Converse API

This notebook demonstrates how to effectively use prompt caching with LangChain's ChatBedrockConverse class to improve performance when working with large documents or repetitive contexts.

## What is Prompt Caching?

Prompt caching allows you to store portions of your conversation context, enabling models to:
- Reuse cached context instead of reprocessing inputs
- Reduce response Time-To-First-Token (TTFT) for subsequent queries
- Potentially lower cost, since tokens read from the cache are typically billed at a reduced rate

This is particularly useful for scenarios like:
- Chat with documents (RAG applications)
- Coding assistants with large code files
- Agentic workflows with complex system prompts
- Few-shot learning with numerous examples
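
In the Bedrock Converse request format, the end of a cacheable prefix is marked with a `cachePoint` content block placed right after the static content. The sketch below builds such a message from plain dictionaries to show the layout; the text values are placeholders and no request is sent.

```python
# Minimal sketch of a Converse-style message containing a cache point.
# The text values are placeholders; nothing is sent to Bedrock here.
static_context = "<document> long, unchanging document text goes here </document>"

message = {
    "role": "user",
    "content": [
        {"text": static_context},                  # static prefix to cache
        {"cachePoint": {"type": "default"}},       # marks the end of the cacheable prefix
        {"text": "What is this document about?"},  # dynamic part, changes per request
    ],
}

# Locate the cache point: the blocks before it form the cached prefix.
cache_index = next(
    i for i, block in enumerate(message["content"]) if "cachePoint" in block
)
cached_prefix = message["content"][:cache_index]
print(f"{len(cached_prefix)} block(s) precede the cache point")
```

Only the content before the marker is eligible for reuse, which is why the question is placed after it.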

## Setup

First, let's install the required packages:

In [None]:
# Install required packages
!pip install -U langchain-aws boto3 pandas matplotlib seaborn requests

## Import Dependencies

In [None]:
# Standard libraries
import json
import time
from enum import Enum

# AWS and external services
import boto3
import requests

# Data processing and visualization
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import seaborn as sns
import numpy as np

# LangChain components
from langchain_aws import ChatBedrockConverse
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

## Initialize the LLM

Let's set up our ChatBedrockConverse model. No caching-specific configuration is needed at this stage; caching is enabled later by adding cache points to the message content:

In [None]:
# Initialize the ChatBedrockConverse model
llm = ChatBedrockConverse(
    model_id="anthropic.claude-3-5-haiku-20241022-v1:0",
    region_name="us-west-2",
    temperature=0,  # Lower temperature for more deterministic responses
    max_tokens=1000  # Adjust based on your needs
)

# Test the model with a simple query
test_response = llm.invoke("Hello, are you ready to demonstrate prompt caching?")
print("Model initialized successfully. Test response:")
print(test_response.content)

## Fetch Sample Documents

To demonstrate prompt caching effectively, we need documents long enough to exceed the model's minimum cacheable prompt size. Let's fetch some AWS blog posts:

In [None]:
# URLs for sample documents
topics = [
    'https://aws.amazon.com/blogs/aws/reduce-costs-and-latency-with-amazon-bedrock-intelligent-prompt-routing-and-prompt-caching-preview/',
    'https://aws.amazon.com/blogs/machine-learning/enhance-conversational-ai-with-advanced-routing-techniques-with-amazon-bedrock/',
    'https://aws.amazon.com/blogs/security/cost-considerations-and-common-options-for-aws-network-firewall-log-management/'
]

# Fetch the first document (raw HTML); fail fast on network or HTTP errors
response = requests.get(topics[0], timeout=30)
response.raise_for_status()
blog = response.text

# Print a preview of the document
print(f"Document length: {len(blog)} characters")
print(f"Preview: {blog[:200]}...")
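
Note that `response.text` is raw HTML, so much of what is sent to the model is markup rather than article text. Bedrock also caches a prefix only when it exceeds a model-specific minimum token count, so it can help to check the document size before benchmarking. The sketch below strips tags with the standard library and estimates the token count; `html_to_text` and `estimate_tokens` are hypothetical helpers, and the four-characters-per-token ratio is a rough assumption rather than a real tokenizer.

```python
from html.parser import HTMLParser


class _TextExtractor(HTMLParser):
    """Collect visible text, skipping the contents of <script> and <style>."""

    def __init__(self):
        super().__init__()
        self._skip_depth = 0
        self.chunks = []

    def handle_starttag(self, tag, attrs):
        if tag in ("script", "style"):
            self._skip_depth += 1

    def handle_endtag(self, tag):
        if tag in ("script", "style") and self._skip_depth > 0:
            self._skip_depth -= 1

    def handle_data(self, data):
        if self._skip_depth == 0 and data.strip():
            self.chunks.append(data.strip())


def html_to_text(html):
    """Return the visible text of an HTML page as a single string."""
    parser = _TextExtractor()
    parser.feed(html)
    parser.close()
    return " ".join(parser.chunks)


def estimate_tokens(text, chars_per_token=4):
    """Very rough token estimate (assumed heuristic, not a tokenizer)."""
    return len(text) // chars_per_token


sample_html = (
    "<html><style>p {color: red}</style>"
    "<body><p>Prompt caching</p><p>demo</p></body></html>"
)
sample_text = html_to_text(sample_html)
print(sample_text)
print(estimate_tokens(sample_text))
```

In this notebook, `html_to_text(blog)` could be passed to the model instead of `blog`, and `estimate_tokens` gives a quick sanity check that the result is still comfortably above the minimum cacheable size documented for your model.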

## Define Helper Functions

Let's create helper functions to work with prompt caching and measure performance:

In [None]:
class CacheMode(Enum):
    ON = "on"
    OFF = "off"
    
    def __lt__(self, other):
        if self.__class__ is other.__class__:
            return self.value < other.value
        return NotImplemented


def chat_with_document_langchain(document, user_query, llm_model, use_cache=True):
    """Chat with a document using LangChain's ChatBedrockConverse with proper prompt caching."""
    
    # Create system message with instructions
    instructions = (
        "I will provide you with a document, followed by a question about its content. "
        "Your task is to analyze the document, extract relevant information, and provide "
        "a comprehensive answer to the question."
    )
    
    document_content = f"Here is the document: <document> {document} </document>"
    
    # Start timing
    start_time = time.time()
    
    # Create messages with cache point if enabled
    messages = [
        SystemMessage(content=instructions),
    ]
    
    # Add document content with cache point if caching is enabled
    if use_cache:
        # This is the key part - add the cache point directly in the message content
        human_message_content = [
            {"type": "text", "text": document_content},
            ChatBedrockConverse.create_cache_point()  # Add cache point here
        ]
        messages.append(HumanMessage(content=human_message_content))
    else:
        messages.append(HumanMessage(content=document_content))
    
    # First invoke to process the document (and cache it if enabled)
    response = llm_model.invoke(messages)
    
    # Now add the user query
    messages.append(HumanMessage(content=user_query))
    
    # Second invoke with the query
    response = llm_model.invoke(messages)
    
    # Calculate elapsed time
    elapsed_time = time.time() - start_time
    
    # Print results
    print(f"Response (elapsed time: {elapsed_time:.2f}s):")
    print(response.content)
    
    # Print usage metrics if available
    if hasattr(response, 'usage_metadata') and response.usage_metadata:
        print("\nUsage metrics:")
        print(json.dumps(response.usage_metadata, indent=2))

        # Check for cache-related metrics in input_token_details.
        # usage_metadata is a dict, so use key access rather than attribute access.
        cache_details = response.usage_metadata.get('input_token_details') or {}
        if cache_details.get('cache_read', 0) > 0:
            print(f"Cache was used! Read tokens: {cache_details['cache_read']}")
        if cache_details.get('cache_creation', 0) > 0:
            print(f"Cache was created! Write tokens: {cache_details['cache_creation']}")
    
    return response, elapsed_time


def add_median_labels(ax):
    """Add median value labels to a boxplot."""
    lines = ax.get_lines()
    boxes = [c for c in ax.get_children() if type(c).__name__ == 'PathPatch']
    lines_per_box = int(len(lines) / len(boxes))
    for median in lines[4:len(lines):lines_per_box]:
        x, y = (data.mean() for data in median.get_data())
        # get text value from the median line
        value = median.get_ydata()[0]
        text = ax.text(x, y, f'{value:.2f}s', ha='center', va='center',
                      fontweight='bold', color='white')
        text.set_path_effects([path_effects.withStroke(linewidth=3, foreground='black')])
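
Because `usage_metadata` is a plain dictionary, the cache counters can also be extracted with a small pure function that is easy to reuse in a benchmark. This is a sketch: `summarize_cache_usage` is a hypothetical helper, and the sample dictionary holds made-up numbers that only mirror the shape of LangChain's `usage_metadata`.

```python
def summarize_cache_usage(usage_metadata):
    """Return cache read/write token counts and a simple status label."""
    details = (usage_metadata or {}).get("input_token_details") or {}
    cache_read = details.get("cache_read", 0) or 0
    cache_creation = details.get("cache_creation", 0) or 0

    if cache_read > 0:
        status = "hit"      # part of the prompt was served from the cache
    elif cache_creation > 0:
        status = "write"    # the prefix was written to the cache on this call
    else:
        status = "none"     # no cache activity was reported
    return {
        "cache_read": cache_read,
        "cache_creation": cache_creation,
        "status": status,
    }


# Illustrative values only, not real measurements.
example_usage = {
    "input_tokens": 5000,
    "output_tokens": 200,
    "total_tokens": 5200,
    "input_token_details": {"cache_read": 4800, "cache_creation": 0},
}
print(summarize_cache_usage(example_usage))
```

A call such as `summarize_cache_usage(response.usage_metadata)` could then be logged next to each timing so that slow runs can be matched to cache misses.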

## Test Document Chat with Caching

Now let's test our document chat function with prompt caching enabled. The key difference is that we're including the cache point directly in the message content:

In [None]:
# Sample questions
questions = [
    'What is this blog post about?',
    'What are the main use cases for prompt caching?',
    'How does prompt caching improve performance?'
]

# First query with caching enabled (this will create the cache)
print("FIRST QUERY (CACHE CREATION):")
print("-" * 50)
response1, time1 = chat_with_document_langchain(blog, questions[0], llm, use_cache=True)

In [None]:
# Second query with caching enabled (this should use the cache)
print("\n\nSECOND QUERY (USING CACHE):")
print("-" * 50)
response2, time2 = chat_with_document_langchain(blog, questions[1], llm, use_cache=True)

In [None]:
# Third query with caching disabled (for comparison)
print("\n\nTHIRD QUERY (NO CACHE):")
print("-" * 50)
response3, time3 = chat_with_document_langchain(blog, questions[2], llm, use_cache=False)

## Benchmark Function

Let's create a function to benchmark the performance of prompt caching:

In [None]:
from time import sleep
def benchmark_prompt_caching(document, questions, llm_model, iterations=3):
    """Benchmark the performance of prompt caching."""
    results = []
    
    # Test with caching enabled
    print("\nBenchmarking with caching ENABLED:")
    for i in range(iterations):
        for q_idx, question in enumerate(questions):
            print(f"Iteration {i+1}, Question {q_idx+1}: {question[:30]}...")
            start_time = time.time()
            response, _ = chat_with_document_langchain(document, question, llm_model, use_cache=True)
            elapsed = time.time() - start_time
            results.append({
                'cache_mode': CacheMode.ON.value,
                'iteration': i+1,
                'question_idx': q_idx+1,
                'time': elapsed
            })
            print(f"Time: {elapsed:.2f}s\n")
    
    # Pause between the cached and uncached runs
    sleep(60)

    # Test with caching disabled
    print("\nBenchmarking with caching DISABLED:")
    for i in range(iterations):
        for q_idx, question in enumerate(questions):
            print(f"Iteration {i+1}, Question {q_idx+1}: {question[:30]}...")
            start_time = time.time()
            response, _ = chat_with_document_langchain(document, question, llm_model, use_cache=False)
            elapsed = time.time() - start_time
            results.append({
                'cache_mode': CacheMode.OFF.value,
                'iteration': i+1,
                'question_idx': q_idx+1,
                'time': elapsed
            })
            print(f"Time: {elapsed:.2f}s\n")
    
    # Convert to DataFrame for analysis
    return pd.DataFrame(results)

## Run Benchmark

Let's run a more systematic benchmark to measure the performance improvements from prompt caching:

In [None]:
# Run the benchmark with a smaller number of iterations for demonstration
benchmark_results = benchmark_prompt_caching(
    document=blog,
    questions=questions,
    llm_model=llm,
    iterations=1  # Adjust based on your needs
)

# Display the results
benchmark_results
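
Before plotting, a per-mode median gives a quick numeric comparison. The sketch below uses only the standard library on a list of records shaped like the rows of `benchmark_results`; `median_by_mode` is a hypothetical helper, and the timing values are placeholders for illustration, not measurements.

```python
from collections import defaultdict
from statistics import median


def median_by_mode(records):
    """Group records by 'cache_mode' and return the median 'time' per mode."""
    grouped = defaultdict(list)
    for row in records:
        grouped[row["cache_mode"]].append(row["time"])
    return {mode: median(times) for mode, times in grouped.items()}


# Placeholder timings for illustration only.
records = [
    {"cache_mode": "on", "time": 2.0},
    {"cache_mode": "on", "time": 3.0},
    {"cache_mode": "on", "time": 4.0},
    {"cache_mode": "off", "time": 5.0},
    {"cache_mode": "off", "time": 6.0},
    {"cache_mode": "off", "time": 7.0},
]

medians = median_by_mode(records)
print(medians)
print(f"Median speedup: {medians['off'] / medians['on']:.2f}x")
```

With the real DataFrame, `benchmark_results.to_dict("records")` yields rows of the same shape, and `benchmark_results.groupby("cache_mode")["time"].median()` is the equivalent pandas one-liner.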

## Visualize Benchmark Results

In [None]:
# Set the style
sns.set_style("whitegrid")

# Create the plot
plt.figure(figsize=(10, 6))
ax = sns.boxplot(x='cache_mode', y='time', data=benchmark_results)

# Add median labels
add_median_labels(ax)

# Set titles and labels
plt.title('Response Time by Cache Mode', fontsize=16)
plt.xlabel('Cache Mode', fontsize=14)
plt.ylabel('Time (seconds)', fontsize=14)

# Show the plot
plt.tight_layout()
plt.show()

## Using Prompt Caching with LangChain Chains

Now let's see how to integrate prompt caching with LangChain chains:

In [None]:
def create_chain_with_caching(llm_model, document):
    """Create a LangChain chain with prompt caching."""
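    # Pass the document as a ready-made HumanMessage rather than a template string:
    # message objects are not parsed for {placeholders}, so braces in the document
    # (common in HTML) are safe and the cache point block passes through unchanged.
    # The full document is used because a short excerpt may fall below the
    # model's minimum cacheable prompt size.
    template = ChatPromptTemplate.from_messages([
        SystemMessage(content="You are a helpful assistant that answers questions about documents."),
        HumanMessage(content=[
            {"type": "text", "text": f"Here is a document: <document> {document} </document>"},
            ChatBedrockConverse.create_cache_point()
        ]),
        ("human", "{question}")
    ])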
    
    # Create the chain
    chain = template | llm_model | StrOutputParser()
    
    return chain


def run_chain_with_timing(chain, document, question):
    """Run a chain with timing."""
    start_time = time.time()
    response = chain.invoke({"question": question})
    elapsed_time = time.time() - start_time
    
    print(f"\nQuestion: {question}")
    print(f"Response (elapsed time: {elapsed_time:.2f}s):")
    print(response[:200] + "..." if len(response) > 200 else response)
    
    return elapsed_time

In [None]:
# Create a chain with prompt caching
chain = create_chain_with_caching(llm, blog)

In [None]:
# Test the chain with multiple queries
print("Running chain with prompt caching:")
time1 = run_chain_with_timing(chain, blog, questions[0])
time2 = run_chain_with_timing(chain, blog, questions[1])
time3 = run_chain_with_timing(chain, blog, questions[2])

print(f"First query time: {time1:.2f}s")
print(f"Second query time: {time2:.2f}s")
print(f"Third query time: {time3:.2f}s")

## Manual Example: Direct Message Construction

Let's look at a more manual example where we construct the messages directly:

In [None]:
# Create messages with cache point
messages = [
    SystemMessage(content="You are a helpful assistant that answers questions about documents."),
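    HumanMessage(content=[
        # Use the full document: a 1,000-character excerpt is likely below the
        # minimum prompt size that Bedrock will cache
        {"type": "text", "text": f"Here is a document: <document> {blog} </document>"},
        ChatBedrockConverse.create_cache_point()  # Add cache point here
    ])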
]

# First invoke to process the document (and cache it)
start_time = time.time()
response = llm.invoke(messages)
print(f"First response time (cache creation): {time.time() - start_time:.2f}s")

# Add a question
messages.append(HumanMessage(content="What is the main topic of this document?"))

# Second invoke with the question (should use cache)
start_time = time.time()
response = llm.invoke(messages)
print(f"Second response time (using cache): {time.time() - start_time:.2f}s")
print(f"Response: {response.content}")
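
A cache hit requires the content before the cache point to be identical between requests, so fingerprinting the static prefix can help when debugging unexpected cache misses. The sketch below hashes the blocks that precede the first cache point; `prefix_fingerprint` is a hypothetical helper that operates on plain content-block dictionaries like the ones used above, and the block texts are placeholders.

```python
import hashlib
import json


def prefix_fingerprint(content_blocks):
    """Hash every content block that appears before the first cache point."""
    prefix = []
    for block in content_blocks:
        if "cachePoint" in block:
            break
        prefix.append(block)
    serialized = json.dumps(prefix, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(serialized.encode("utf-8")).hexdigest()


cache_point = {"cachePoint": {"type": "default"}}

first = [{"type": "text", "text": "Here is a document: sample"}, cache_point]
# Same prefix, different content after the cache point.
same = [
    {"type": "text", "text": "Here is a document: sample"},
    cache_point,
    {"type": "text", "text": "A different question"},
]
# One extra space in the prefix is enough to change the fingerprint.
changed = [{"type": "text", "text": "Here is a document:  sample"}, cache_point]

print(prefix_fingerprint(first) == prefix_fingerprint(same))
print(prefix_fingerprint(first) == prefix_fingerprint(changed))
```

Logging the fingerprint alongside the `cache_read` count makes it easier to tell whether a miss came from a changed prefix or from an expired cache entry.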

## Best Practices for Prompt Caching with LangChain

Here are some best practices for using prompt caching with LangChain:

1. **Include Cache Point in Message Content**: The cache point must be included directly in the message content as a special content block, not just as part of the configuration.

2. **Place Cache Point After Static Content**: Place the cache point after the static content (like documents or system prompts) that you want to cache.

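3. **Keep the Cached Prefix Identical**: A cache hit requires everything before the cache point to match the earlier request exactly, so keep the static content, its ordering, and the cache point position unchanged across related requests.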

4. **Monitor Cache Metrics**: Check the `cache_read` and `cache_creation` counts in `usage_metadata["input_token_details"]` on each response to confirm that caching is working as expected.

5. **Structure Messages Properly**: Separate static content (like documents, system prompts) from dynamic content (user queries) to maximize caching benefits.

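6. **Consider Cache Lifetime**: Cached prompts expire after a short period of inactivity. On Amazon Bedrock the default time to live is five minutes at the time of writing, and the timer resets on each cache hit, so check the current documentation for your model and keep related requests close together.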
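
The cache lifetime point can be made concrete with a small bookkeeping class that records when the cached prefix was last used and reports whether it is probably still warm. This is a sketch under assumptions: `CacheWarmthTracker` is a hypothetical helper, the TTL is a parameter to set from the Bedrock documentation for your model (300 seconds below is only an example value), and the tracker merely estimates server-side state.

```python
import time


class CacheWarmthTracker:
    """Estimate whether a server-side prompt cache entry is still alive."""

    def __init__(self, ttl_seconds, clock=time.monotonic):
        self.ttl_seconds = ttl_seconds
        self._clock = clock          # injectable clock, useful for testing
        self._last_used = None

    def record_use(self):
        """Call after a request that wrote to or read from the cache."""
        self._last_used = self._clock()

    def is_probably_warm(self):
        """True if the last recorded use is more recent than the TTL."""
        if self._last_used is None:
            return False
        return (self._clock() - self._last_used) < self.ttl_seconds


# A fake clock lets the example run instantly.
fake_now = [0.0]
tracker = CacheWarmthTracker(ttl_seconds=300, clock=lambda: fake_now[0])

before_first_use = tracker.is_probably_warm()
tracker.record_use()
fake_now[0] = 120.0
after_two_minutes = tracker.is_probably_warm()
fake_now[0] = 500.0
after_long_gap = tracker.is_probably_warm()

print(before_first_use, after_two_minutes, after_long_gap)
```

Calling `record_use()` whenever a response reports nonzero `cache_read` or `cache_creation` keeps the estimate aligned with a TTL that resets on each hit.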

## Conclusion

This notebook demonstrated how to use prompt caching with LangChain's ChatBedrockConverse to improve performance when working with large documents or repetitive contexts. Key takeaways:

- Prompt caching can significantly reduce response times for subsequent queries
- The cache point must be included directly in the message content as a special content block
- The `create_cache_point()` method makes it easy to generate the cache point content block
- Caching works well with LangChain's chains and other abstractions
- Performance benefits are most noticeable with large documents or complex system prompts

By leveraging prompt caching in your LangChain applications, you can create more responsive and efficient AI experiences while potentially reducing costs.