# Cohere Rerank

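[Cohere website](https://cohere.com/). This notebook reranks chunks retrieved from two Chroma vector stores with Cohere Rerank. It does this first through the LangChain `CohereRerank` integration and then directly through the Cohere Python SDK.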

In [191]:
import os
import chromadb
import cohere
from dotenv import load_dotenv, find_dotenv
from langchain_chroma import Chroma
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_openai import OpenAIEmbeddings
from langchain_cohere import CohereRerank

load_dotenv(find_dotenv())

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
COHERE_API_KEY = os.getenv('COHERE_API_KEY')
DATABASE_PATH = os.getenv('DATABASE_PATH')

# Fail fast with a clear message. Otherwise a missing DATABASE_PATH only surfaces in the
# next cell as an opaque TypeError from os.path.join(None, ...).
_missing_env_vars = [
    name
    for name, value in (
        ('OPENAI_API_KEY', OPENAI_API_KEY),
        ('COHERE_API_KEY', COHERE_API_KEY),
        ('DATABASE_PATH', DATABASE_PATH),
    )
    if not value
]
if _missing_env_vars:
    raise EnvironmentError(
        "Missing required environment variables (set them in your .env file): "
        + ", ".join(_missing_env_vars)
    )

In [192]:
EMBEDDING_MODEL = 'text-embedding-ada-002'
COHERE_RERANK_MODEL = 'rerank-multilingual-v3.0'

# One persisted Chroma directory per chunking strategy, each under a sub-folder named after the embedding model.
database_path_basic, database_path_title = (
    os.path.join(DATABASE_PATH, 'Unstructured', chunking_strategy, EMBEDDING_MODEL)
    for chunking_strategy in ('basic', 'by_title')
)

In [193]:
# Sample question (German): "How can a duty to provide information be enforced within a
# household community?" Used as the query for both the LangChain and the Cohere SDK reranking below.
QUERY = "Wie kann man eine Auskunftspflicht in einer Haushaltsgemeinschaft durchsetzen?"

In [194]:
def pretty_output_text(text: str, words_per_line: int = 10) -> str:
    """Re-flow text so that each output line holds at most ``words_per_line`` words.

    Existing line breaks are kept: every input line is wrapped on its own and is
    terminated with a newline. Words are split on single spaces, and each word is
    emitted followed by one space, so output lines end with a trailing space.

    Args:
        text: Text to re-flow.
        words_per_line: Maximum number of words per output line (must be non-zero;
            0 raises ZeroDivisionError).

    Returns:
        The re-flowed text.
    """
    pieces = []
    for source_line in text.split('\n'):
        tokens = source_line.split(' ')
        token_count = len(tokens)
        for position, token in enumerate(tokens, start=1):
            pieces.append(f"{token} ")
            # Break after every full group of words, but never after the line's last word.
            if position % words_per_line == 0 and position < token_count:
                pieces.append('\n')
        pieces.append('\n')
    return ''.join(pieces)


def pretty_output_docs(docs: list, show_metadata=True, show_full_path=True, query=None) -> None:
    """Print LangChain documents (e.g. reranked chunks) in a readable block layout.

    Args:
        docs: Documents exposing ``page_content`` and a ``metadata`` dict. The metadata
            keys ``source``, ``page_number`` and ``relevance_score`` are printed; a
            missing key is shown as ``None``.
        show_metadata: If True, print source path, page number and relevance score.
        show_full_path: If True, print the full ``source`` path; otherwise only its basename.
        query: Query printed in the header. Defaults to the notebook-level ``QUERY``.

    Returns:
        None. Everything is written to stdout.
    """
    header_query = QUERY if query is None else query
    print(f"QUERY: {header_query}")
    print('*' * 150, end='\n\n')
    for chunk_number, doc in enumerate(docs, start=1):
        print(f"CHUNK #{chunk_number}:")
        if show_metadata:
            metadata = doc.metadata
            # Use .get() so a chunk without some key prints None instead of raising KeyError.
            # Presumably only reranked documents carry `relevance_score`.
            source_path = metadata.get('source')
            if source_path and not show_full_path:
                source_path = os.path.basename(source_path)
            print(f"Source:\t\t\t{source_path}")
            print(f"Page Number:\t\t{metadata.get('page_number')}")
            print(f"Relevance Score:\t{metadata.get('relevance_score')}")

        print('-' * 150)
        print(pretty_output_text(doc.page_content, 12))
        print('=' * 150)


def pretty_output_rerankresponse(rerankresponses: list, show_additional_info=True, query=None) -> None:
    """Print the results of a Cohere ``rerank`` call in a readable block layout.

    Args:
        rerankresponses: The ``.results`` list of a rerank response. Each item exposes
            ``index``, ``relevance_score`` and ``document``. ``document`` may be None if
            rerank was not called with ``return_documents=True`` (per Cohere SDK docs;
            confirm for the installed version).
        show_additional_info: If True, print each result's index and relevance score.
        query: Query printed in the header. Defaults to the notebook-level ``QUERY``.

    Returns:
        None. Everything is written to stdout.
    """
    header_query = QUERY if query is None else query
    print(f"QUERY: {header_query}")
    print('*' * 150, end='\n\n')
    for rank, result in enumerate(rerankresponses, start=1):
        print(f"CHUNK #{rank}:")
        if show_additional_info:
            # `index` is the position of this document in the list that was sent to rerank.
            print(f"Index:\t\t\t{result.index}")
            print(f"Relevance Score:\t{result.relevance_score}")

        print('-' * 150)
        document = getattr(result, 'document', None)
        if document is None:
            print("<document text not returned; call rerank with return_documents=True>")
        else:
            print(pretty_output_text(document.text, 12))
        print('=' * 150)

## Create Retriever

In [195]:
# One persistent Chroma client per chunking strategy, plus the collection to load from each.
chroma_client_basic = chromadb.PersistentClient(path=database_path_basic)
collection_name_basic = 'collection_1500'

chroma_client_title = chromadb.PersistentClient(path=database_path_title)
collection_name_title = 'collection_1800'

In [196]:
# Both stores are queried with the same embedding model, so one embeddings client is shared.
embedding_function = OpenAIEmbeddings(model=EMBEDDING_MODEL, api_key=OPENAI_API_KEY)

# create_collection_if_not_exists=False: do not silently create an empty collection
# if a collection name is wrong.
vectorstore_basic = Chroma(
    collection_name=collection_name_basic,
    client=chroma_client_basic,
    embedding_function=embedding_function,
    create_collection_if_not_exists=False,
)

vectorstore_title = Chroma(
    collection_name=collection_name_title,
    client=chroma_client_title,
    embedding_function=embedding_function,
    create_collection_if_not_exists=False,
)

### Default Retriever

In [189]:
# Number of candidate chunks (`k`) each base retriever returns; the reranker then keeps only its top_n of them.
n_retrieved_docs = 20

In [190]:
# Base retrievers: default search settings except k, which is set to n_retrieved_docs.
# Their output is the candidate list that Cohere Rerank reorders.
default_retriever_basic = vectorstore_basic.as_retriever(search_kwargs=dict(k=n_retrieved_docs))
default_retriever_title = vectorstore_title.as_retriever(search_kwargs=dict(k=n_retrieved_docs))

## Using LangChain Integration

In [197]:
# LangChain wrapper around Cohere Rerank; keeps the 3 highest-scoring chunks.
compressor = CohereRerank(model=COHERE_RERANK_MODEL, cohere_api_key=COHERE_API_KEY, top_n=3)

In [198]:
# Chain each base retriever with the shared Cohere compressor, so its results are reranked and cut to top_n.
compressor_retriever_basic = ContextualCompressionRetriever(
    base_retriever=default_retriever_basic, base_compressor=compressor
)
compressor_retriever_title = ContextualCompressionRetriever(
    base_retriever=default_retriever_title, base_compressor=compressor
)

In [199]:
# Retrieve n_retrieved_docs candidates per store and let Cohere rerank them down to top_n LangChain Documents.
compressed_docs_basic = compressor_retriever_basic.invoke(QUERY)
compressed_docs_title = compressor_retriever_title.invoke(QUERY)

In [204]:
# compressed_docs_basic

In [205]:
# pretty_output_docs(compressed_docs_basic, show_full_path=False)

In [152]:
# compressed_docs_title

In [206]:
# pretty_output_docs(compressed_docs_title, show_full_path=False)

## Using Cohere API

Each result has an additional attribute, `index`: the position of the document in the list passed to `rerank` (here, the retriever's output).  
Example:
- If the retriever returns 20 documents and we rerank them and keep the top 5, the indices of those 5 documents can be anywhere from 0 to 19.
- If a document in the reranked list has index 5, it was the 6th document in the original retriever list.

This makes it easy to see how reranking reordered the documents.

In [155]:
cohere = cohere.Client(
    api_key=COHERE_API_KEY,
)

In [158]:
docs_basic = default_retriever_basic.invoke(QUERY)
docs_title = default_retriever_title.invoke(QUERY)

In [161]:
# extract only text, cause the cohere.rerank method can not process langchain.Document objects

list_text_docs_basic = [doc.page_content for doc in docs_basic]
list_text_docs_title = [doc.page_content for doc in docs_title]

In [172]:
compressed_docs_basic = cohere.rerank(
    model=COHERE_RERANK_MODEL,
    query=QUERY,
    top_n=3,
    return_documents=True,
    documents=list_text_docs_basic,
).results

compressed_docs_title = cohere.rerank(
    model=COHERE_RERANK_MODEL,
    query=QUERY,
    top_n=3,
    return_documents=True,
    documents=list_text_docs_title,
).results

In [184]:
# compressed_docs_basic

In [185]:
# pretty_output_rerankresponse(compressed_docs_basic)

In [186]:
# compressed_docs_title

In [188]:
# pretty_output_rerankresponse(compressed_docs_title)