# Document retrieval using Cohere Embedding and Rerank


---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

---

Retrieval Augmented Generation (RAG) is a process in which a language model retrieves contextual documents from an external data source and uses them to ground its response. This notebook focuses on the retrieval part of that process: it demonstrates how to use Cohere Embed and Rerank for document ingestion and improved retrieval.

The Cohere Platform empowers enterprises and developers to use Large Language Models (LLMs) privately and securely with AWS JumpStart deployment. Cohere’s LLMs have been available through Amazon SageMaker since January 2023. Customers can subscribe to [Cohere’s LLMs through AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?id=87af0c85-6cf9-4ed8-bee0-b40ce65167e0) and use them in Amazon SageMaker.

[Cohere Embed Model v3 - English](https://aws.amazon.com/marketplace/pp/prodview-qd64mji3pbnvk) allows you to classify, embed, and tokenize text. 
[Cohere Rerank 3 Model - English](https://aws.amazon.com/marketplace/pp/prodview-rqhxjsjanb3gy) allows you to rank a given set of documents or chunks by their relevance to a query.

The notebook walks through the following steps:


1. Model Subscription and Deployment: Instructions are provided for subscribing to Cohere's Rerank and Embed models through AWS Marketplace and deploying them on SageMaker.
2. Endpoint Deployment: Functions are defined to get the model package ARNs for Cohere's Rerank and Embed models. SageMaker endpoints are then created for these models.
3. Document Ingestion: A set of sample documents (customer service queries) is defined and embedded using Cohere's Embed model. The embeddings are then added to a FAISS index for efficient similarity search.

4. Retrieval and Reranking: The notebook demonstrates how to:

    a. Embed a sample question

    b. Retrieve similar documents using the FAISS index

    c. Rerank the retrieved documents using Cohere's Rerank model

    d. Compare the raw vector search results with the reranked results, showing how reranking can improve the relevance of retrieved documents

5. Cleanup: Instructions for deleting the SageMaker endpoints are provided.
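
The retrieve-then-rerank flow outlined in the steps above can be sketched without any deployed endpoint. The snippet below is a minimal illustration in plain Python: the two-dimensional vectors are made-up embeddings, and `toy_rerank` is a hypothetical word-overlap scorer standing in for Cohere Rerank. Neither reflects how the actual models behave; the point is only the shape of the two-stage pipeline.

```python
# Toy two-stage retrieval: vector search first, then rerank the shortlist.
# The vectors and the word-overlap scorer are illustrative stand-ins for
# Cohere Embed and Cohere Rerank, not their actual behavior.


def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def retrieve(query_vec, doc_vecs, k):
    """Return the indices of the k documents closest to the query vector."""
    order = sorted(range(len(doc_vecs)), key=lambda i: squared_l2(query_vec, doc_vecs[i]))
    return order[:k]


def toy_rerank(query, candidate_texts):
    """Return candidate positions ordered by word overlap with the query."""
    query_words = set(query.lower().split())
    scores = [len(query_words & set(text.lower().split())) for text in candidate_texts]
    return sorted(range(len(candidate_texts)), key=lambda i: -scores[i])


docs = ["reset my password", "return a defective item", "track my order"]
doc_vecs = [[0.9, 0.1], [0.2, 0.8], [0.3, 0.7]]  # made-up embeddings

# Stage 1: shortlist by vector distance.
candidates = retrieve([0.3, 0.75], doc_vecs, k=2)

# Stage 2: reorder the shortlist; positions map back to corpus indices.
positions = toy_rerank("how do I return an item", [docs[c] for c in candidates])
reranked = [candidates[p] for p in positions]

print(candidates, reranked)  # [2, 1] [1, 2]
```

In this toy run the vector search places document 2 first, while the second stage moves document 1 to the top because it shares more words with the query.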



## 1. Subscribe to the Cohere Rerank and Embed model packages for deployment in SageMaker JumpStart

To subscribe to the model packages:

1. Open the model package listing pages [Cohere Rerank 3 Model - English](https://aws.amazon.com/marketplace/pp/prodview-rqhxjsjanb3gy) and [Cohere Embed Model v3 - English](https://aws.amazon.com/marketplace/pp/prodview-qd64mji3pbnvk).
1. On the AWS Marketplace listing, click on the **Continue to subscribe** button.
1. On the **Subscribe to this software** page, review and click on **"Accept Offer"** if you and your organization agree with the EULA, pricing, and support terms.
1. Once you click on the **Continue to configuration** button and choose a **region**, you will see a **Product Arn** displayed. This is the model package ARN that you need to specify when creating a deployable model using Boto3. Copy the ARN corresponding to your region and specify it in the deployment cells below.

In [None]:
!pip install --upgrade sagemaker --quiet
!pip install --upgrade cohere-sagemaker --quiet
!pip install --upgrade cohere-aws --quiet
!pip install --upgrade faiss-cpu --quiet

In [None]:
import sagemaker
import faiss
import numpy as np
from cohere_aws import Client
import boto3

## 2. Deploy the Cohere endpoints

In [9]:
def rerank_model_package_arn():
    """Get the Cohere rerank model package arn"""

    cohere_package = "cohere-rerank-english-v3-6-121-3fed5583784c39d39b65ad0ca878e6b4"
    model_package_map = {
        "us-east-1": f"arn:aws:sagemaker:us-east-1:865070037744:model-package/{cohere_package}",
        "us-east-2": f"arn:aws:sagemaker:us-east-2:057799348421:model-package/{cohere_package}",
        "us-west-1": f"arn:aws:sagemaker:us-west-1:382657785993:model-package/{cohere_package}",
        "us-west-2": f"arn:aws:sagemaker:us-west-2:594846645681:model-package/{cohere_package}",
        "ca-central-1": f"arn:aws:sagemaker:ca-central-1:470592106596:model-package/{cohere_package}",
        "eu-central-1": f"arn:aws:sagemaker:eu-central-1:446921602837:model-package/{cohere_package}",
        "eu-west-1": f"arn:aws:sagemaker:eu-west-1:985815980388:model-package/{cohere_package}",
        "eu-west-2": f"arn:aws:sagemaker:eu-west-2:856760150666:model-package/{cohere_package}",
        "eu-west-3": f"arn:aws:sagemaker:eu-west-3:843114510376:model-package/{cohere_package}",
        "eu-north-1": f"arn:aws:sagemaker:eu-north-1:136758871317:model-package/{cohere_package}",
        "ap-southeast-1": f"arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/{cohere_package}",
        "ap-southeast-2": f"arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/{cohere_package}",
        "ap-northeast-2": f"arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/{cohere_package}",
        "ap-northeast-1": f"arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/{cohere_package}",
        "ap-south-1": f"arn:aws:sagemaker:ap-south-1:077584701553:model-package/{cohere_package}",
        "sa-east-1": f"arn:aws:sagemaker:sa-east-1:270155090741:model-package/{cohere_package}",
    }

    region = boto3.Session().region_name
    if region not in model_package_map.keys():
        raise Exception(f"{region} UNSUPPORTED REGION")

    return model_package_map[region]


def embedd_model_package_arn():
    """Get the Cohere embed model package arn"""

    # Replace the package name below with the model package you want to deploy
    cohere_package = "cohere-embed-english-v3-8-1108-7e0739d809803e9380e75b954811a0cb"

    # Mapping for Model Packages
    model_package_map = {
        "us-east-1": f"arn:aws:sagemaker:us-east-1:865070037744:model-package/{cohere_package}",
        "us-east-2": f"arn:aws:sagemaker:us-east-2:057799348421:model-package/{cohere_package}",
        "us-west-1": f"arn:aws:sagemaker:us-west-1:382657785993:model-package/{cohere_package}",
        "us-west-2": f"arn:aws:sagemaker:us-west-2:594846645681:model-package/{cohere_package}",
        "ca-central-1": f"arn:aws:sagemaker:ca-central-1:470592106596:model-package/{cohere_package}",
        "eu-central-1": f"arn:aws:sagemaker:eu-central-1:446921602837:model-package/{cohere_package}",
        "eu-west-1": f"arn:aws:sagemaker:eu-west-1:985815980388:model-package/{cohere_package}",
        "eu-west-2": f"arn:aws:sagemaker:eu-west-2:856760150666:model-package/{cohere_package}",
        "eu-west-3": f"arn:aws:sagemaker:eu-west-3:843114510376:model-package/{cohere_package}",
        "eu-north-1": f"arn:aws:sagemaker:eu-north-1:136758871317:model-package/{cohere_package}",
        "ap-southeast-1": f"arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/{cohere_package}",
        "ap-southeast-2": f"arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/{cohere_package}",
        "ap-northeast-2": f"arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/{cohere_package}",
        "ap-northeast-1": f"arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/{cohere_package}",
        "ap-south-1": f"arn:aws:sagemaker:ap-south-1:077584701553:model-package/{cohere_package}",
        "sa-east-1": f"arn:aws:sagemaker:sa-east-1:270155090741:model-package/{cohere_package}",
    }
    region = boto3.Session().region_name
    if region not in model_package_map.keys():
        raise Exception(f"{region} UNSUPPORTED REGION")

    return model_package_map[region]
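
The two functions above differ only in the package name, because both packages are published from the same account in each region. As a sketch of one way to avoid the duplication, the snippet below keeps a single region-to-account table and builds the ARN from it. `MARKETPLACE_ACCOUNTS` and `model_package_arn` are hypothetical names introduced for illustration, and only three regions are copied from the maps above for brevity.

```python
# Hypothetical helper: one region -> account-ID table shared by both packages.
# Only three regions are listed here; the values are copied from the maps above.
MARKETPLACE_ACCOUNTS = {
    "us-east-1": "865070037744",
    "us-west-2": "594846645681",
    "eu-west-1": "985815980388",
}


def model_package_arn(package_name, region):
    """Build the model package ARN for a region, or fail with a clear message."""
    try:
        account = MARKETPLACE_ACCOUNTS[region]
    except KeyError:
        raise ValueError(f"Unsupported region: {region}") from None
    return f"arn:aws:sagemaker:{region}:{account}:model-package/{package_name}"


arn = model_package_arn(
    "cohere-rerank-english-v3-6-121-3fed5583784c39d39b65ad0ca878e6b4", "us-west-2"
)
print(arn)
```

Passing the region explicitly, rather than reading it inside the function, also makes the lookup easy to check without an AWS session.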

In [33]:
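# Use the session region so that it matches the region used for the ARN lookup above
region_name = boto3.Session().region_name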
endpoints = {
    "rerank": {
        "name": "cohere-rerank-english-v3-0",
        "model_package_arn": rerank_model_package_arn(),
    },
    "embed": {
        "name": "cohere-embedd-english-v3-0",
        "model_package_arn": embedd_model_package_arn(),
    },
}

In [None]:
# Create SageMaker Endpoints
co = Client(region_name=region_name)
co.create_endpoint(
    arn=endpoints["rerank"]["model_package_arn"],
    endpoint_name=endpoints["rerank"]["name"],
    instance_type="ml.g5.xlarge",
    n_instances=1,
    role=sagemaker.get_execution_role(),
)
co.create_endpoint(
    arn=endpoints["embed"]["model_package_arn"],
    endpoint_name=endpoints["embed"]["name"],
    instance_type="ml.g4dn.2xlarge",
    n_instances=1,
    role=sagemaker.get_execution_role(),
)

In [31]:
# The Cohere client seems to lack the ability to handle different endpoints.
# Hence, we wrap the required functions


def rerank(**kwargs):
    co.connect_to_endpoint(endpoint_name=endpoints["rerank"]["name"])
    return co.rerank(**kwargs)


def embed(**kwargs):
    co.connect_to_endpoint(endpoint_name=endpoints["embed"]["name"])
    return co.embed(**kwargs)

## 3. Ingest the documents into the vector store
In this section, we embed a set of documents and store the embeddings in an in-memory vector store backed by a `faiss` index. This allows us to perform retrieval the way it is done in a RAG application.

In [28]:
documents = [
    {
        "Title": "Incorrect Password",
        "Content": "Hello, I have been trying to access my account for the past hour and it keeps saying my password is incorrect. Can you please help me?",
    },
    {
        "Title": "Confirmation Email Missed",
        "Content": "Hi, I recently purchased a product from your website but I never received a confirmation email. Can you please look into this for me?",
    },
    {
        "Title": "Questions about Return Policy",
        "Content": "Hello, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective.",
    },
    {
        "Title": "Customer Support is Busy",
        "Content": "Good morning, I have been trying to reach your customer support team for the past week but I keep getting a busy signal. Can you please help me?",
    },
    {
        "Title": "Received Wrong Item",
        "Content": "Hi, I have a question about my recent order. I received the wrong item and I need to return it.",
    },
    {
        "Title": "Customer Service is Unavailable",
        "Content": "Hello, I have been trying to reach your customer support team for the past hour but I keep getting a busy signal. Can you please help me?",
    },
    {
        "Title": "Return Policy for Defective Product",
        "Content": "Hi, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective.",
    },
    {
        "Title": "Wrong Item Received",
        "Content": "Good morning, I have a question about my recent order. I received the wrong item and I need to return it.",
    },
    {
        "Title": "Return Defective Product",
        "Content": "Hello, I have a question about the return policy for this product. I purchased it a few weeks ago and it is defective.",
    },
    {
        "Title": "Shipping Delay",
        "Content": "I placed an order last week and it still hasn't shipped. Can you tell me what's causing the delay?",
    },
    {
        "Title": "Refund Request",
        "Content": "I returned an item two weeks ago but haven't received my refund yet. When can I expect it?",
    },
    {
        "Title": "Product Availability",
        "Content": "Is the XYZ model back in stock? I've been waiting for it to be available.",
    },
    {
        "Title": "Coupon Code Not Working",
        "Content": "I'm trying to use the SUMMER20 coupon code at checkout, but it's not applying. Is it still valid?",
    },
    {
        "Title": "Account Locked",
        "Content": "My account has been locked due to too many login attempts. How can I unlock it?",
    },
    {
        "Title": "Damaged Package",
        "Content": "My order arrived today, but the package was damaged. What should I do?",
    },
    {
        "Title": "Cancel Subscription",
        "Content": "I want to cancel my monthly subscription. Can you guide me through the process?",
    },
    {
        "Title": "Missing Part",
        "Content": "The product I received is missing a crucial part. Can you send me the missing piece?",
    },
    {
        "Title": "Change Delivery Address",
        "Content": "I need to change the delivery address for my upcoming order. Is that possible?",
    },
    {
        "Title": "Price Match Request",
        "Content": "I found the same product cheaper on another website. Do you offer price matching?",
    },
    {
        "Title": "Gift Card Balance",
        "Content": "How can I check the remaining balance on my gift card?",
    },
]

In [54]:
retrieval_top_k = 5

input_type = "search_document"
embed_response = embed(
    texts=[d["Content"] for d in documents], input_type=input_type, truncate="END"
)
embeddings = embed_response.embeddings
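
The 20 sample documents fit into a single request, but a larger corpus would likely need to be split into batches. The sketch below assumes a per-request cap of 96 texts, which is the limit documented for Cohere's Embed API at the time of writing; treat that value as an assumption and check it for your model version. `embed_in_batches` and `fake_embed` are hypothetical names, and `fake_embed` merely stands in for the endpoint call.

```python
def batched(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def embed_in_batches(texts, embed_fn, batch_size=96):
    """Embed texts batch by batch and return the vectors in input order."""
    vectors = []
    for batch in batched(texts, batch_size):
        vectors.extend(embed_fn(batch))
    return vectors


def fake_embed(batch):
    """Stand-in for the endpoint call: one 1-dimensional vector per text."""
    return [[float(len(text))] for text in batch]


texts = [f"doc {i}" for i in range(200)]
vectors = embed_in_batches(texts, fake_embed)
print(len(vectors))  # 200
```

With the wrapper defined earlier in this notebook, `embed_fn` could be something like `lambda batch: embed(texts=batch, input_type="search_document", truncate="END").embeddings`.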

In [55]:
index = faiss.IndexFlatL2(len(embeddings[0]))
index.reset()  # Clear any pre-existing index
index.add(np.array(embeddings, dtype=np.float32))
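
`IndexFlatL2` performs an exhaustive search: it compares the query against every stored vector and ranks them by squared Euclidean distance. The plain-Python sketch below mirrors that behavior on toy data to make the returned distances and indices easier to interpret. `flat_l2_search` is a hypothetical helper written for illustration, not part of `faiss`.

```python
def flat_l2_search(vectors, query, k):
    """Brute-force k-nearest-neighbor search by squared L2 distance.

    Returns (distances, indices), both sorted from nearest to farthest.
    """
    scored = sorted(
        (sum((v - q) ** 2 for v, q in zip(vec, query)), i)
        for i, vec in enumerate(vectors)
    )
    top = scored[:k]
    return [d for d, _ in top], [i for _, i in top]


vectors = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
distances, indices = flat_l2_search(vectors, query=[1.0, 1.0], k=3)
print(distances, indices)  # [1.0, 2.0, 2.0] [1, 0, 2]
```

A smaller distance means a closer match, so the first index returned is the best candidate according to the vector search alone.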

## 4. Retrieve and rerank 
In this step, we take a question and retrieve the relevant documents from the vector store. Subsequently, we rerank the retrieved documents.
Note: we use the `input_type` "search_query" to generate the embedding of the sample question, as recommended by Cohere. With this setting, the vector search already returns well-ranked documents. However, if you use "search_document" as the `input_type` for the question, the effect of reranking becomes visible quickly. For more details, please visit https://cohere.com/blog/introducing-embed-v3.

In [78]:
question = "What emails have been about returning items?"

input_type = "search_query"  # see https://cohere.com/blog/introducing-embed-v3
query_embedding = embed(texts=[question], input_type=input_type, truncate="END")

In [79]:
distances, result = index.search(
    np.array(query_embedding.embeddings[0], dtype=np.float32).reshape(1, -1), k=retrieval_top_k
)

In [80]:
response = rerank(
    documents=[documents[i] for i in result.flatten()],
    query=question,
    rank_fields=["Title", "Content"],
    top_n=retrieval_top_k,
)

Print the raw vector search results alongside the reranked results:

In [None]:
print("Raw vector search results:")
for i, k in enumerate(result.flatten()):
    print(f"Retrieval rank: {i}, title: {documents[k]['Title']}")


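print("Reranked results:")
# res.index is the position of the document in the list passed to rerank(),
# i.e. its rank in the raw vector search results
for new_rank, res in enumerate(response):
    print(f"Rerank rank: {new_rank}, retrieval rank: {res.index}, title: {res.document['Title']}")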
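
Each reranked result refers to a position in the list that was passed to the reranker, not to a position in the full `documents` list. To get back to the original corpus index, the position has to be looked up in the vector search result. The sketch below illustrates this mapping with made-up values; `RerankHit` is a hypothetical stand-in for the client's result objects, assumed here to expose `index` and `relevance_score`.

```python
from collections import namedtuple

# Hypothetical stand-in for a rerank result; the values are made up.
RerankHit = namedtuple("RerankHit", ["index", "relevance_score"])

# Corpus indices in vector-search order (what result.flatten() would hold).
retrieved_ids = [6, 2, 8, 4, 7]

# Reranker output, best first; index points into retrieved_ids.
hits = [RerankHit(3, 0.91), RerankHit(0, 0.85), RerankHit(4, 0.40)]

rows = []
for new_rank, hit in enumerate(hits):
    corpus_id = retrieved_ids[hit.index]  # map back to the full document list
    rows.append((new_rank, hit.index, corpus_id))
    print(f"rerank rank {new_rank}: was retrieval rank {hit.index}, corpus index {corpus_id}")
```

In this made-up run, the document that the vector search placed fourth (retrieval rank 3) ends up first after reranking.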

## 5. Clean up the SageMaker endpoints
In this section, we delete the previously created SageMaker endpoints to avoid incurring unnecessary costs.

In [None]:
for endpoint in endpoints.values():
    co.connect_to_endpoint(endpoint_name=endpoint["name"])
    co.delete_endpoint()
co.close()

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|question_answering_retrieval_augmented_generation|question_answering_Cohere+langchain_jumpstart.ipynb)
