In [4]:
from abc import ABC, abstractmethod
from transformers import AutoModel, AutoTokenizer
import torch

In [5]:
def _auto_detect_device() -> str:
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    return device

In [6]:
_auto_detect_device()

'cuda'

In [10]:
class BaseReranker(ABC):
    def _detect_device(self, device: str) -> str:
        if device == "auto":
            device = _auto_detect_device()
        return device

    def rerank(self, query: str, documents: list[str]) -> tuple[list[int], list[float]]:
        scores = self._rerank(query, documents)
        indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=False)
        return indices, scores

    @abstractmethod
    def _rerank(self, query: str, documents: list[str]) -> list[float]:
        pass

In [11]:
def _insert_token(
    output: dict,
    insert_token_id: int,
    insert_position: int = 1,
    token_type_id: int = 0,
    attention_value: int = 1,
):
    updated_output = {}
    for key in output:
        updated_tensor_list = []
        for seqs in output[key]:
            if len(seqs.shape) == 1:
                seqs = seqs.unsqueeze(0)
            for seq in seqs:
                first_part = seq[:insert_position]
                second_part = seq[insert_position:]
                new_element = (
                    torch.tensor([insert_token_id])
                    if key == "input_ids"
                    else torch.tensor([token_type_id])
                )
                if key == "attention_mask":
                    new_element = torch.tensor([attention_value])
                updated_seq = torch.cat((first_part, new_element, second_part), dim=0)
                updated_tensor_list.append(updated_seq)
        updated_output[key] = torch.stack(updated_tensor_list)
    return updated_output

In [33]:
def _colbert_score(q_reps, p_reps, q_mask: torch.Tensor, p_mask: torch.Tensor):
    print("shape of q_reps, p_reps, q_mask, p_mask")
    print(q_reps.shape, p_reps.shape, q_mask.shape, p_mask.shape)
    print(q_reps, p_reps, q_mask, p_mask)

    # calc max sim
    token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
    token_scores = token_scores.masked_fill(p_mask.unsqueeze(0).unsqueeze(0) == 0, -1e4)
    scores, _ = token_scores.max(-1)
    scores = scores.sum(1) / q_mask[:, 1:].sum(-1, keepdim=True)
    return scores

In [34]:
class ColbertReranker(BaseReranker):
    def __init__(
        self,
        model_name: str,
        device: str = "auto",
        use_fp16=True,
        max_length=512,
        query_token: str = "[unused0]",
        document_token: str = "[unused1]",
        normalize: bool = True,
    ):
        device = self._detect_device(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = device
        self.model.to(device)
        if use_fp16 and "cuda" in device:
            self.model.half()
        self.model.eval()
        self.model.max_length = max_length
        self.max_length = max_length
        self.query_token_id: int = self.tokenizer.convert_tokens_to_ids(query_token)  # type: ignore
        self.document_token_id: int = self.tokenizer.convert_tokens_to_ids(document_token)  # type: ignore
        self.normalize = normalize

    def _encode(self, texts: list[str], insert_token_id: int):
        encoding = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            max_length=self.max_length - 1,  # for insert token
            truncation=True,
        )
        encoding = _insert_token(encoding, insert_token_id)  # type: ignore
        encoding = {key: value.to(self.device) for key, value in encoding.items()}
        return encoding

    def _query_encode(self, query: list[str]):
        return self._encode(query, self.query_token_id)

    def _document_encode(self, documents: list[str]):
        return self._encode(documents, self.document_token_id)

    def _to_embs(self, encoding) -> torch.Tensor:
        with torch.no_grad():
            embs = self.model(**encoding).last_hidden_state.squeeze(1)
        if self.normalize:
            embs = embs / embs.norm(dim=-1, keepdim=True)
        return embs

    def _rerank(self, query: str, documents: list[str]) -> list[float]:
        query_encoding = self._query_encode([query])
        documents_encoding = self._document_encode(documents)
        query_embeddings = self._to_embs(query_encoding)
        document_embeddings = self._to_embs(documents_encoding)
        scores = (
            _colbert_score(
                query_embeddings,
                document_embeddings,
                query_encoding["attention_mask"],
                documents_encoding["attention_mask"],
            )
            .cpu()
            .tolist()[0]
        )
        return scores

In [35]:
model_name = "colbert-ir/colbertv2.0"

In [36]:
reranker = ColbertReranker(model_name=model_name)

In [37]:
query = "What is natural language processing?"
documents = [
    "Natural language processing enables computers to understand human language.",
    "Machine learning involves teaching computers to learn from data.",
    "The history of natural language processing is closely linked to the field of linguistics."
]

In [38]:
indices, scores = reranker.rerank(query, documents)

shape of q_reps, p_reps, q_mask, p_mask
torch.Size([1, 9, 768]) torch.Size([3, 18, 768]) torch.Size([1, 9]) torch.Size([3, 18])
tensor([[[-0.0202,  0.0142, -0.0175,  ..., -0.0148, -0.0366,  0.0108],
         [-0.0149, -0.0103, -0.0072,  ...,  0.0144,  0.0254,  0.0212],
         [-0.0472,  0.0152,  0.0272,  ...,  0.0078, -0.0006,  0.0309],
         ...,
         [ 0.0449, -0.0036,  0.0112,  ..., -0.0260, -0.0234, -0.0324],
         [-0.0297, -0.0320,  0.0055,  ..., -0.0367, -0.0282, -0.0243],
         [-0.0082,  0.0248, -0.0427,  ..., -0.0274, -0.0592, -0.0046]]],
       device='cuda:0', dtype=torch.float16) tensor([[[-0.0314,  0.0219, -0.0083,  ..., -0.0190, -0.0556,  0.0139],
         [-0.0259,  0.0394,  0.0102,  ...,  0.0078, -0.0348,  0.0133],
         [-0.0047,  0.0051,  0.0015,  ...,  0.0262,  0.0074,  0.0116],
         ...,
         [-0.0087, -0.0051, -0.0146,  ...,  0.0090,  0.0162,  0.0319],
         [-0.0292, -0.0147,  0.0010,  ..., -0.0268, -0.0195,  0.0202],
         [-0.025

In [32]:
for idx, score in zip(indices, scores):
    print(f"Document {idx} Score: {score}")
    print(f"Content: {documents[idx]}\n")

Document 1 Score: 0.9501953125
Content: Machine learning involves teaching computers to learn from data.

Document 2 Score: 0.587890625
Content: The history of natural language processing is closely linked to the field of linguistics.

Document 0 Score: 0.85693359375
Content: Natural language processing enables computers to understand human language.

