# Reranking with FlashRank

In [122]:
import os
import chromadb
from flashrank import Ranker, RerankRequest
from dotenv import load_dotenv, find_dotenv
from langchain_chroma import Chroma
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_openai import OpenAIEmbeddings
from langchain.retrievers.document_compressors.flashrank_rerank import FlashrankRerank

# Load variables from the .env file located by find_dotenv() into the process environment.
load_dotenv(find_dotenv())

# API key for the OpenAI embedding model used by the vector stores.
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
# Root directory that holds the persisted Chroma databases.
DATABASE_PATH = os.getenv('DATABASE_PATH')
if not DATABASE_PATH:
    # Fail fast with an actionable message instead of an opaque TypeError
    # from os.path.join(None, ...) when the database paths are built.
    raise RuntimeError("DATABASE_PATH is not set; add it to your .env file or environment.")

In [123]:
# Embedding model the Chroma databases were built with; it is also the last path component.
EMBEDDING_MODEL = 'text-embedding-ada-002'
# FlashRank model identifier used by both the LangChain wrapper and the raw Ranker.
FLASHRANK_RERANK_MODEL = 'rank-T5-flan'

# <DATABASE_PATH>/Unstructured/<variant>/<embedding model>; the variant names 'basic' and
# 'by_title' presumably refer to the Unstructured chunking strategy used — confirm.
database_path_basic = os.path.join(DATABASE_PATH, 'Unstructured', 'basic', EMBEDDING_MODEL)
database_path_title = os.path.join(DATABASE_PATH, 'Unstructured', 'by_title', EMBEDDING_MODEL)

In [124]:
# German example query ("How can a duty to provide information be enforced within a
# household community?"), used by every retrieval and reranking cell below.
QUERY = (
    "Wie kann man eine Auskunftspflicht in einer Haushaltsgemeinschaft durchsetzen?"
)

In [125]:
def pretty_output_text(text: str, words_per_line: int = 10) -> str:
    """Re-wrap text so that each output line holds at most ``words_per_line`` words.

    Existing line breaks are kept: every input line is wrapped on its own, and every
    resulting line is terminated with a newline. Words are split on single spaces, so
    runs of spaces are preserved (as empty "words") rather than collapsed.

    Args:
        text: Text to wrap.
        words_per_line: Maximum number of words per output line; must be >= 1.

    Returns:
        The wrapped text. No line carries trailing whitespace, and the result always
        ends with a newline.

    Raises:
        ValueError: If ``words_per_line`` is smaller than 1.
    """
    if words_per_line < 1:
        raise ValueError(f"words_per_line must be >= 1, got {words_per_line}")

    wrapped_lines = []
    for text_part in text.split('\n'):
        words = text_part.split(' ')
        # Each consecutive group of `words_per_line` words becomes one output line.
        for start in range(0, len(words), words_per_line):
            wrapped_lines.append(' '.join(words[start:start + words_per_line]))

    return ''.join(f"{line}\n" for line in wrapped_lines)


def pretty_output_docs(docs: list, show_metadata=True, show_full_path=True,
                       query: "str | None" = None, words_per_line: int = 12) -> None:
    """Print LangChain documents (e.g. reranked chunks) in a human-readable layout.

    Args:
        docs: Documents exposing ``page_content`` and a ``metadata`` dict that contains
            at least a ``'source'`` key.
        show_metadata: Whether to print source, page number, retrieval index and score.
        show_full_path: Print the full source path (True) or only its file name (False).
        query: Query shown in the header; defaults to the notebook-level ``QUERY``.
        words_per_line: Number of words per line when wrapping the chunk text.

    Missing ``'page_number'``, ``'id'`` or ``'relevance_score'`` entries are printed as
    ``n/a``, so documents that were not reranked (and presumably carry no relevance
    score) can be printed as well.
    """
    if query is None:
        query = QUERY

    print(f"QUERY: {query}")
    print('*' * 150, end='\n\n')
    for i, doc in enumerate(docs):
        print(f"CHUNK #{i+1}:")
        if show_metadata:
            metadata = doc.metadata
            source_path = metadata['source'] if show_full_path else os.path.basename(metadata['source'])
            print(f"Source:\t\t\t{source_path}")
            print(f"Page Number:\t\t{metadata.get('page_number', 'n/a')}")
            print(f"Idx in Retrieving:\t{metadata.get('id', 'n/a')}")
            print(f"Relevance Score:\t{metadata.get('relevance_score', 'n/a')}")

        print('-' * 150)
        print(pretty_output_text(doc.page_content, words_per_line))
        print('=' * 150)


def pretty_output_flashrank_results(flashrank_results: list, show_additional_info=True, show_full_path=True,
                                    query: "str | None" = None, words_per_line: int = 12) -> None:
    """Print raw FlashRank ``Ranker.rerank`` results in a human-readable layout.

    Args:
        flashrank_results: Passage dicts as returned by ``Ranker.rerank``. Each is expected
            to carry the ``'id'``, ``'text'`` and ``'metadata'`` keys of the submitted
            passage plus a ``'score'`` (presumably added by the ranker).
        show_additional_info: Whether to print source, page number, retrieval index and score.
        show_full_path: Print the full source path (True) or only its file name (False).
        query: Query shown in the header; defaults to the notebook-level ``QUERY``.
        words_per_line: Number of words per line when wrapping the passage text.
    """
    if query is None:
        query = QUERY

    print(f"QUERY: {query}")
    print('*' * 150, end='\n\n')
    for i, result in enumerate(flashrank_results):
        print(f"CHUNK #{i+1}:")
        if show_additional_info:
            metadata = result['metadata']
            source_path = metadata['source'] if show_full_path else os.path.basename(metadata['source'])
            print(f"Source:\t\t\t{source_path}")
            print(f"Page Number:\t\t{metadata['page_number']}")
            print(f"Idx in Retrieving:\t{result['id']}")
            # Three tabs so the value lines up with the other fields above.
            print(f"Score:\t\t\t{result['score']}")

        print('-' * 150)
        print(pretty_output_text(result['text'], words_per_line))
        print('=' * 150)

## Create Retriever

In [126]:
# One persistent Chroma client per database variant.
chroma_client_basic = chromadb.PersistentClient(path=database_path_basic)
chroma_client_title = chromadb.PersistentClient(path=database_path_title)

# Collection to open inside each database.
# NOTE(review): the numeric suffix presumably encodes the chunk size used at ingestion — confirm.
collection_name_basic = 'collection_1500'
collection_name_title = 'collection_1800'

In [127]:
# Both database paths are keyed by EMBEDDING_MODEL, so one embedding client serves both
# stores. Queries must be embedded with the same model the collections were built with.
embedding_function = OpenAIEmbeddings(model=EMBEDDING_MODEL, api_key=OPENAI_API_KEY)

# create_collection_if_not_exists=False: the collections must already exist on disk;
# they are opened, not created.
vectorstore_basic = Chroma(
    collection_name=collection_name_basic,
    client=chroma_client_basic,
    embedding_function=embedding_function,
    create_collection_if_not_exists=False,
)

vectorstore_title = Chroma(
    collection_name=collection_name_title,
    client=chroma_client_title,
    embedding_function=embedding_function,
    create_collection_if_not_exists=False,
)

### Default Retriever

In [128]:
# Number of candidate chunks fetched from each vector store before reranking.
n_retrieved_docs: int = 20

In [129]:
# Retrievers using each vector store's default search; each returns `n_retrieved_docs` chunks.
default_retriever_basic = vectorstore_basic.as_retriever(search_kwargs={'k': n_retrieved_docs})
default_retriever_title = vectorstore_title.as_retriever(search_kwargs={'k': n_retrieved_docs})

## Using the LangChain Integration

- `FlashrankRerank()` uses `ms-marco-MultiBERT-L-12` as its default reranker (this notebook overrides it via `FLASHRANK_RERANK_MODEL`).

- This model was trained on the [MS MARCO dataset](https://microsoft.github.io/msmarco/).

- See the FlashRank [GitHub repository](https://github.com/PrithivirajDamodaran/FlashRank) for more details.

- The following models are available for reranking:
    ```python
    # flashrank_rerank.py -> Ranker.py -> Config.py
    # Config.py
    {
        "ms-marco-TinyBERT-L-2-v2": "flashrank-TinyBERT-L-2-v2.onnx",
        "ms-marco-MiniLM-L-12-v2": "flashrank-MiniLM-L-12-v2_Q.onnx",
        "ms-marco-MultiBERT-L-12": "flashrank-MultiBERT-L12_Q.onnx",
        "rank-T5-flan": "flashrank-rankt5_Q.onnx",
        "ce-esci-MiniLM-L12-v2": "flashrank-ce-esci-MiniLM-L12-v2_Q.onnx",
        "rank_zephyr_7b_v1_full": "rank_zephyr_7b_v1_full.Q4_K_M.gguf",
        "miniReranker_arabic_v1": "miniReranker_arabic_v1.onnx"
    }
    ```

- **Cross-encoder** models can process at most **512 tokens**.

- **LLM-based** models can process **8096 tokens**.


**From [GitHub](https://github.com/PrithivirajDamodaran/FlashRank):**
| Model Name | Description | Size | Notes |
|------------|-------------|------|-------|
| `ms-marco-TinyBERT-L-2-v2` | - | ~4MB | [Model card](https://huggingface.co/cross-encoder/ms-marco-TinyBERT-L-2) |
| `ms-marco-MiniLM-L-12-v2` | - | ~34MB | [Model card](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) |
| `rank-T5-flan` | Best non cross-encoder reranker | ~110MB | [Model card](https://huggingface.co/bergum/rank-T5-flan) |
| `ms-marco-MultiBERT-L-12` | Multi-lingual, supports 100+ languages | ~150MB | [Supported languages](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages) |
| `ce-esci-MiniLM-L12-v2` | Fine-tuned on Amazon ESCI dataset | - | [Model card](https://huggingface.co/metarank/ce-esci-MiniLM-L12-v2) |
| `rank_zephyr_7b_v1_full` | 4-bit-quantised GGUF | ~4GB | [Model card](https://huggingface.co/castorini/rank_zephyr_7b_v1_full) |
| `miniReranker_arabic_v1` | - | - | [Model card](https://huggingface.co/prithivida/miniReranker_arabic_v1) |


In [130]:
# LangChain wrapper around FlashRank: reranks the retrieved candidates and keeps the best 5.
# NOTE(review): QUERY is German, and the model table lists `ms-marco-MultiBERT-L-12` as the
# multilingual option — consider comparing it against FLASHRANK_RERANK_MODEL.
compressor = FlashrankRerank(model=FLASHRANK_RERANK_MODEL, top_n=5)

In [119]:
# Chain each base retriever with the shared FlashRank compressor: retrieve first, then rerank.
compressor_retriever_basic = ContextualCompressionRetriever(
    base_retriever=default_retriever_basic, base_compressor=compressor
)
compressor_retriever_title = ContextualCompressionRetriever(
    base_retriever=default_retriever_title, base_compressor=compressor
)

In [120]:
# Retrieve candidates from each store and keep only the reranked top results (basic first, then title).
compressed_docs_basic, compressed_docs_title = [
    retriever.invoke(QUERY)
    for retriever in (compressor_retriever_basic, compressor_retriever_title)
]

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


In [131]:
# compressed_docs_basic

In [73]:
# pretty_output_docs(compressed_docs_basic, show_full_path=False)

In [74]:
# compressed_docs_title

In [75]:
# pretty_output_docs(compressed_docs_title, show_full_path=False)

## Using FlashRank API

- `Ranker()` uses `ms-marco-TinyBERT-L-2-v2` as its default reranker.

- See the FlashRank [GitHub repository](https://github.com/PrithivirajDamodaran/FlashRank) for more details.

**From [GitHub](https://github.com/PrithivirajDamodaran/FlashRank):**
| Model Name | Description | Size | Notes |
|------------|-------------|------|-------|
| `ms-marco-TinyBERT-L-2-v2` | Default model in API | ~4MB | [Model card](https://huggingface.co/cross-encoder/ms-marco-TinyBERT-L-2) |
| `ms-marco-MiniLM-L-12-v2` | - | ~34MB | [Model card](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) |
| `rank-T5-flan` | Best non cross-encoder reranker | ~110MB | [Model card](https://huggingface.co/bergum/rank-T5-flan) |
| `ms-marco-MultiBERT-L-12` | Multi-lingual, supports 100+ languages | ~150MB | [Supported languages](https://github.com/google-research/bert/blob/master/multilingual.md#list-of-languages) |
| `ce-esci-MiniLM-L12-v2` | Fine-tuned on Amazon ESCI dataset | - | [Model card](https://huggingface.co/metarank/ce-esci-MiniLM-L12-v2) |
| `rank_zephyr_7b_v1_full` | 4-bit-quantised GGUF | ~4GB | [Model card](https://huggingface.co/castorini/rank_zephyr_7b_v1_full) |
| `miniReranker_arabic_v1` | - | - | [Model card](https://huggingface.co/prithivida/miniReranker_arabic_v1) |

- **Cross-encoder** models can process at most **512 tokens**.

- **LLM-based** models can process **8096 tokens**.

The API does not compress (truncate) the result list: if you retrieve 20 documents from the database, it reranks all 20 and **returns all of them**.  
To keep only the top-k documents, slice the returned list (*this is the same approach LangChain uses internally*).

In [104]:
# Standalone FlashRank ranker for the direct-API comparison.
# NOTE(review): max_length=4096 is far above the 512-token limit stated above for
# cross-encoders — confirm the chosen model benefits from it. The LangChain compressor is
# built without an explicit max_length, so the two paths may truncate passages differently.
ranker = Ranker(max_length=4096, model_name=FLASHRANK_RERANK_MODEL)

In [105]:
# Un-reranked candidates straight from each vector store (basic first, then title).
docs_basic, docs_title = [
    retriever.invoke(QUERY)
    for retriever in (default_retriever_basic, default_retriever_title)
]

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


In [106]:
# Passages for RerankRequest: `id` records the original retrieval position, and the metadata
# dict travels along so reranked results can be traced back to their source.
docs_list_basic = [
    {'id': position, 'text': document.page_content, 'metadata': document.metadata}
    for position, document in enumerate(docs_basic)
]
docs_list_title = [
    {'id': position, 'text': document.page_content, 'metadata': document.metadata}
    for position, document in enumerate(docs_title)
]

In [107]:
# How many reranked results to inspect; keep in sync with `top_n` of the LangChain compressor.
top_n = 5

# One rerank request per database variant, all for the same QUERY.
rerankerrequest_basic = RerankRequest(passages=docs_list_basic, query=QUERY)
rerankerrequest_title = RerankRequest(passages=docs_list_title, query=QUERY)

In [108]:
# Rerank every passage; all of them come back (presumably best-first, since later cells
# slice `[:top_n]`). Basic is processed first, then title.
results_basic, results_title = [
    ranker.rerank(request)
    for request in (rerankerrequest_basic, rerankerrequest_title)
]

In [116]:
# results_basic[:top_n]

In [115]:
# pretty_output_flashrank_results(results_basic[:top_n], show_full_path=False)

In [117]:
# results_title[:top_n]

In [114]:
# pretty_output_flashrank_results(results_title[:top_n], show_full_path=False)