Merged
29 changes: 24 additions & 5 deletions README.md
@@ -14,7 +14,8 @@ Welcome to `rerankers`! Our goal is to provide users with a simple API to use an

## Updates

- v0.2.0: 🆕 [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, Basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
- v0.3.0: 🆕 Many changes! Experimental support for RankLLM, directly backed by the [rank-llm library](https://github.com/castorini/rank_llm). A new `Document` object, courtesy of joint work by [@bclavie](https://github.com/bclavie) and [@Anmol6](https://github.com/Anmol6). This object is transparent, but now offers support for `metadata` stored alongside each document. Many small QoL changes (RankedResults can be iterated over directly...)
- v0.2.0: [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers, Basic async support thanks to [@tarunamasa](https://github.com/tarunamasa), MixedBread.ai reranking API
- v0.1.2: Voyage reranking API
- v0.1.1: Langchain integration fixed!
- v0.1.0: Initial release
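The new `Document` object mentioned in the v0.3.0 notes can be illustrated with a minimal self-contained sketch. Note this is a hypothetical stand-in dataclass, not the real `rerankers.documents.Document`; it only shows the idea of `metadata` travelling alongside each document's text:

```python
from dataclasses import dataclass, field
from typing import Optional, Union

# Hypothetical stand-in for the rerankers `Document` object, shown only to
# illustrate metadata stored alongside each document.
@dataclass
class Doc:
    text: str
    doc_id: Optional[Union[str, int]] = None
    metadata: dict = field(default_factory=dict)

docs = [
    Doc(text="Paris is the capital of France.", doc_id=0, metadata={"source": "wiki"}),
    Doc(text="Berlin is the capital of Germany.", doc_id=1, metadata={"source": "wiki"}),
]
print(docs[0].metadata["source"])  # wiki
```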
@@ -59,6 +60,9 @@ pip install "rerankers[api]"
# FlashRank rerankers (ONNX-optimised, very fast on CPU)
pip install "rerankers[flashrank]"

# RankLLM rerankers (better RankGPT + support for local models such as RankZephyr and RankVicuna)
pip install "rerankers[rankllm]"

# All of the above
pip install "rerankers[all]"
```
@@ -105,12 +109,27 @@ ranker = Reranker("rankgpt3", api_key = API_KEY)
# RankGPT with another LLM provider (see the litellm docs for supported model names)
ranker = Reranker("MY_LLM_NAME", model_type = "rankgpt", api_key = API_KEY)

# RankLLM with default GPT (GPT-4o)
ranker = Reranker("rankllm", api_key = API_KEY)

# RankLLM with specified GPT models
ranker = Reranker('gpt-4-turbo', model_type="rankllm", api_key = API_KEY)

# EXPERIMENTAL: RankLLM with RankZephyr
ranker = Reranker('rankzephyr')

# ColBERTv2 reranker
ranker = Reranker("colbert")

# ... Or a non-default colbert model:
ranker = Reranker(model_name_or_path, model_type = "colbert")

# Flashrank
ranker = Reranker('flashrank')

# ... Or a specific model
ranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')

```

_Rerankers will always try to infer the model you're trying to use based on its name, but it's always safer to pass a `model_type` argument to it if you can!_
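The name-based inference works by substring matching against a model-type mapping. A simplified, hypothetical sketch of the idea follows (the real mapping in `rerankers/reranker.py` covers many more keys, and the fallback shown here is an assumption):

```python
# Simplified, hypothetical sketch of name-based model-type inference;
# the real mapping in rerankers/reranker.py is larger.
MODEL_TYPE_BY_SUBSTRING = {
    "rankllm": "RankLLMRanker",
    "gpt": "RankGPTRanker",
    "colbert": "ColBERTRanker",
    "flashrank": "FlashRankRanker",
    "cohere": "APIRanker",
}

def infer_model_type(model_name: str) -> str:
    lowered = model_name.lower()
    for key, ranker in MODEL_TYPE_BY_SUBSTRING.items():
        if key in lowered:
            return ranker
    # Assumed fallback: treat unrecognised names as cross-encoders.
    return "TransformerRanker"

print(infer_model_type("colbert-ir/colbertv2.0"))  # ColBERTRanker
```

Substring matching is why passing an explicit `model_type` is safer: an ambiguous model name may match the wrong key.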
@@ -180,18 +199,18 @@ Legend:

Models:
- ✅ Any standard SentenceTransformer or Transformers cross-encoder
- 🟠 RankGPT (Implemented using original repo, but missing the rankllm's repo improvements)
- ✅ RankGPT (Available both via the original RankGPT implementation and the improved RankLLM one)
- ✅ T5-based pointwise rankers (InRanker, MonoT5...)
- ✅ Cohere, Jina, Voyage and MixedBread API rerankers
- ✅ [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) rerankers (ONNX-optimised models, very fast on CPU)
- 🟠 ColBERT-based reranker - not a model initially designed for reranking, but quite strong (Implementation could be optimised and is from a third-party implementation.)
- 📍 MixedBread API (Reranking API not yet released)
- 📍⭐ RankLLM/RankZephyr (Proper RankLLM implementation will replace the RankGPT one, and introduce RankZephyr support)
- 🟠⭐ RankLLM/RankZephyr: supported by wrapping the [rank-llm](https://github.com/castorini/rank_llm) library! Support for RankZephyr/RankVicuna is untested, but RankLLM with GPT models fully works!
- 📍 LiT5

Features:
- ✅ Metadata!
- ✅ Reranking
- ✅ Consistency notebooks to ensure performance on `scifact` matches the literature for any given model implementation (except RankGPT, whose results are harder to reproduce).
- ✅ ONNX runtime support --> Offered through [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) -- in line with the philosophy of the lib, we won't reinvent the wheel when @PrithivirajDamodaran is doing amazing work!
- 📍 Training on Python >=3.10 (via interfacing with other libraries)
- 📍 ONNX runtime support --> Unlikely to be immediate
- ❌(📍Maybe?) Training via rerankers directly
395 changes: 259 additions & 136 deletions examples/overview.ipynb

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions pyproject.toml
@@ -14,7 +14,7 @@ packages = [
name = "rerankers"


version = "0.2.0"
version = "0.3.0"

description = "A unified API for various document re-ranking models."

@@ -52,12 +52,13 @@ dependencies = [
]

[project.optional-dependencies]
all = ["transformers", "torch", "litellm", "requests", "sentencepiece", "protobuf", "flashrank"]
transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
all = ["transformers", "torch", "litellm", "requests", "sentencepiece", "protobuf", "flashrank", "rank-llm"]
transformers = ["transformers", "torch", "sentencepiece", "protobuf"]
api = ["requests"]
gpt = ["litellm"]
flashrank = ["flashrank"]
rankllm = ["rank-llm"]
dev = ["ruff", "isort", "pytest", "ipyprogress", "ipython", "ranx", "ir_datasets", "srsly"]

[project.urls]
"Homepage" = "https://github.com/bclavie/rerankers"
"Homepage" = "https://github.com/answerdotai/rerankers"
2 changes: 1 addition & 1 deletion rerankers/__init__.py
@@ -2,4 +2,4 @@
from rerankers.documents import Document

__all__ = ["Reranker", "Document"]
__version__ = "0.2.0"
__version__ = "0.3.0"
7 changes: 7 additions & 0 deletions rerankers/models/__init__.py
@@ -38,3 +38,10 @@
    AVAILABLE_RANKERS["FlashRankRanker"] = FlashRankRanker
except ImportError:
    pass

try:
    from rerankers.models.rankllm_ranker import RankLLMRanker

    AVAILABLE_RANKERS["RankLLMRanker"] = RankLLMRanker
except ImportError:
    pass
22 changes: 10 additions & 12 deletions rerankers/models/flashrank_ranker.py
@@ -4,12 +4,9 @@


from typing import Union, List, Optional, Tuple
from rerankers.utils import (
    vprint,
    ensure_docids,
    ensure_docs_list,
)
from rerankers.utils import vprint, prep_docs
from rerankers.results import RankedResults, Result
from rerankers.documents import Document


class FlashRankRanker(BaseRanker):
@@ -34,20 +31,21 @@ def tokenize(self, inputs: Union[str, List[str], List[Tuple[str, str]]]):
    def rank(
        self,
        query: str,
        docs: List[str],
        doc_ids: Optional[List[Union[str, int]]] = None,
        docs: Union[str, List[str], Document, List[Document]],
        doc_ids: Optional[Union[List[str], List[int]]] = None,
        metadata: Optional[List[dict]] = None,
    ) -> RankedResults:
        docs = ensure_docs_list(docs)
        doc_ids = ensure_docids(doc_ids, len(docs))
        passages = [{"id": doc_id, "text": doc} for doc_id, doc in zip(doc_ids, docs)]
        docs = prep_docs(docs, doc_ids, metadata)
        passages = [
            {"id": doc_idx, "text": doc.text} for doc_idx, doc in enumerate(docs)
        ]

        rerank_request = RerankRequest(query=query, passages=passages)
        flashrank_results = self.model.rerank(rerank_request)

        ranked_results = [
            Result(
                doc_id=result["id"],
                text=result["text"],
                document=docs[idx],
                score=result["score"],
                rank=idx + 1,
            )
2 changes: 1 addition & 1 deletion rerankers/models/rankgpt_rankers.py
@@ -126,7 +126,7 @@ def _query_llm(self, messages: List[Dict[str, str]]) -> str:
    def rank(
        self,
        query: str,
        docs: Union[Document, List[Document]],
        docs: Union[str, List[str], Document, List[Document]],
        doc_ids: Optional[Union[List[str], List[int]]] = None,
        metadata: Optional[List[dict]] = None,
        rank_start: int = 0,
1 change: 0 additions & 1 deletion rerankers/models/rankllm.py

This file was deleted.

76 changes: 76 additions & 0 deletions rerankers/models/rankllm_ranker.py
@@ -0,0 +1,76 @@
from typing import Optional, Union, List
from rerankers.models.ranker import BaseRanker
from rerankers.documents import Document
from rerankers.results import RankedResults, Result
from rerankers.utils import prep_docs

from rank_llm.data import Candidate, Query, Request
from rank_llm.rerank.vicuna_reranker import VicunaReranker
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
from rank_llm.rerank.rank_gpt import SafeOpenai
from rank_llm.rerank.reranker import Reranker as rankllm_Reranker


class RankLLMRanker(BaseRanker):
    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        lang: str = "en",
        verbose: int = 1,
    ) -> "RankLLMRanker":
        self.api_key = api_key
        self.model = model
        self.verbose = verbose
        self.lang = lang

        if "zephyr" in self.model.lower():
            self.rankllm_ranker = ZephyrReranker()
        elif "vicuna" in self.model.lower():
            self.rankllm_ranker = VicunaReranker()
        elif "gpt" in self.model.lower():
            self.rankllm_ranker = rankllm_Reranker(
                SafeOpenai(model=self.model, context_size=4096, keys=self.api_key)
            )

    def rank(
        self,
        query: str,
        docs: Union[str, List[str], Document, List[Document]],
        doc_ids: Optional[Union[List[str], List[int]]] = None,
        metadata: Optional[List[dict]] = None,
        rank_start: int = 0,
        rank_end: int = 0,
    ) -> RankedResults:
        docs = prep_docs(docs, doc_ids, metadata)

        request = Request(
            query=Query(text=query, qid=1),
            candidates=[
                Candidate(doc={"text": doc.text}, docid=doc_idx, score=1)
                for doc_idx, doc in enumerate(docs)
            ],
        )

        rankllm_results = self.rankllm_ranker.rerank(
            request,
            rank_end=len(docs) if rank_end == 0 else rank_end,
            window_size=min(20, len(docs)),
            step=10,
        )

        ranked_docs = []

        for rank, result in enumerate(rankllm_results.candidates, start=rank_start):
            ranked_docs.append(
                Result(
                    document=docs[result.docid],
                    rank=rank,
                )
            )

        return RankedResults(results=ranked_docs, query=query, has_scores=False)

    def score(self):
        print("Listwise ranking models like RankLLM cannot output scores!")
        return None
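The `window_size` and `step` arguments in the `rerank` call above come from sliding-window listwise reranking, where a long candidate list is reranked in overlapping chunks small enough to fit in the LLM's context. A hypothetical pure-Python sketch of how such windows could be scheduled (an illustration of the idea, not part of rank_llm's API):

```python
def sliding_windows(n_docs: int, window_size: int = 20, step: int = 10):
    # Walk from the bottom of the candidate list upward, so documents
    # promoted by one window can compete again in the next, overlapping one.
    end = n_docs
    while end > 0:
        start = max(0, end - window_size)
        yield (start, end)
        if start == 0:
            break
        end -= step

print(list(sliding_windows(30)))  # [(10, 30), (0, 20)]
```

With the defaults above, 30 candidates are covered by two overlapping windows of 20, which is why `window_size` is capped at `min(20, len(docs))` for short lists.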
13 changes: 13 additions & 0 deletions rerankers/reranker.py
@@ -1,4 +1,5 @@
from typing import Optional
import warnings
from rerankers.models import AVAILABLE_RANKERS
from rerankers.models.ranker import BaseRanker
from rerankers.utils import vprint
@@ -21,6 +22,7 @@
    "rankgpt": {"en": "gpt-4-turbo-preview", "other": "gpt-4-turbo-preview"},
    "rankgpt3": {"en": "gpt-3.5-turbo", "other": "gpt-3.5-turbo"},
    "rankgpt4": {"en": "gpt-4", "other": "gpt-4"},
    "rankllm": {"en": "gpt-4o", "other": "gpt-4o"},
    "colbert": {
        "en": "colbert-ir/colbertv2.0",
        "fr": "bclavie/FraColBERTv2",
@@ -38,6 +40,7 @@
    "APIRanker": "api",
    "ColBERTRanker": "transformers",
    "FlashRankRanker": "flashrank",
    "RankLLMRanker": "rankllm",
}

PROVIDERS = ["cohere", "jina", "voyage", "mixedbread.ai"]
@@ -72,6 +75,7 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
            "colbert": "ColBERTRanker",
            "cross-encoder": "TransformerRanker",
            "flashrank": "FlashRankRanker",
            "rankllm": "RankLLMRanker",
        }
        return model_mapping.get(explicit_model_type, explicit_model_type)
    else:
@@ -80,6 +84,8 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
            "lit5": "LiT5Ranker",
            "t5": "T5Ranker",
            "inranker": "T5Ranker",
            "rankllm": "RankLLMRanker",
            "rankgpt": "RankGPTRanker",
            "gpt": "RankGPTRanker",
            "zephyr": "RankZephyr",
            "colbert": "ColBERTRanker",
@@ -88,9 +94,16 @@ def _get_model_type(model_name: str, explicit_model_type: Optional[str] = None)
            "voyage": "APIRanker",
            "ms-marco-minilm-l-12-v2": "FlashRankRanker",
            "ms-marco-multibert-l-12": "FlashRankRanker",
            "vicuna": "RankLLMRanker",
            "zephyr": "RankLLMRanker",
        }
        for key, value in model_mapping.items():
            if key in model_name:
                if key == "gpt":
                    warnings.warn(
                        "The key 'gpt' currently defaults to the rough rankGPT implementation. From version 0.0.5 onwards, 'gpt' will default to RankLLM instead. Please specify the 'rankgpt' `model_type` if you want to keep the current behaviour",
                        DeprecationWarning,
                    )
                return value
        if (
            any(