In [15]:
from transformers import AutoTokenizer, AutoModel
import torch
from torch.nn.functional import cosine_similarity

In [16]:
# Load the tokenizer and the encoder from the same pretrained checkpoint, then
# put the encoder in evaluation mode.
# NOTE(review): AutoModel presumably instantiates only the backbone encoder; if
# this checkpoint also ships a ColBERT projection layer it would not be applied
# here — confirm against the load-time warnings / model card.
MODEL_NAME = "colbert-ir/colbertv2.0"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.eval()

In [22]:
# Toy retrieval fixture: one question plus six candidate passages to rank.
query = "What is the capital of France?"

documents = [
    # The passage that actually answers the query.
    "Paris is the capital and most populous city of France.",
    # Same topic (a country's capital) but about other countries.
    "Berlin is the capital and largest city of Germany.",
    "Madrid is the capital of Spain and the largest municipality in both the Community of Madrid and Spain as a whole.",
    "The capital of Italy is Rome. It is also the country's most populated city.",
    # Unrelated filler text.
    "This is some filling garbage document",
    # The first passage with "Paris" and "France." stripped off.
    "is the capital and most populous city of",
]

In [23]:
# Encode the query into one L2-normalised embedding per token.
query_tokens = tokenizer(query, return_tensors='pt', padding=True, truncation=True)

# Inference only: no_grad() avoids recording an autograd graph for the forward
# pass, which would otherwise be kept alive by `query_embeddings`.
with torch.no_grad():
    query_embeddings = model(**query_tokens).last_hidden_state

# Unit-length token vectors along the hidden dimension.
query_embeddings = torch.nn.functional.normalize(query_embeddings, p=2, dim=-1)

In [24]:
# Accumulator for the aggregated relevance scores, one float per document.
max_similarity_scores_all_docs = list()

In [25]:
# Score every document against the query using per-token maximum similarity.
#
# The score list is (re)built here so this cell is idempotent: re-running it
# replaces the scores instead of appending a second copy to a list left over
# from a previous run.
max_similarity_scores_all_docs = []

# Loop-invariant view of the query: drop the batch dimension and add a
# broadcast axis -> (num_query_tokens, 1, hidden).
query_token_vectors = query_embeddings.squeeze(0)[:, None, :]

# Inference only: skip autograd bookkeeping for every forward pass.
with torch.no_grad():
    for document in documents:
        document_tokens = tokenizer(document, return_tensors='pt', padding=True, truncation=True)
        document_embeddings = model(**document_tokens).last_hidden_state
        document_embeddings = torch.nn.functional.normalize(document_embeddings, p=2, dim=-1)

        # Pairwise cosine similarity between every query token and every
        # document token -> (num_query_tokens, num_document_tokens).
        similarity_scores = cosine_similarity(
            query_token_vectors,
            document_embeddings.squeeze(0)[None, :, :],
            dim=-1,
        )
        # For each query token keep its best-matching document token ...
        max_similarity_scores = similarity_scores.max(dim=1).values
        # ... then average over query tokens to get a single document score.
        document_score = max_similarity_scores.mean()
        max_similarity_scores_all_docs.append(document_score.item())

In [26]:
# Report the aggregated score of each document, numbered from 1 in the order
# the scores were appended.
for i, score in enumerate(max_similarity_scores_all_docs):
    line = f"Document {i+1}: Score = {score:.4f}"
    print(line)

Document 1: Score = 0.8900
Document 2: Score = 0.7455
Document 3: Score = 0.7532
Document 4: Score = 0.7800
Document 5: Score = 0.2972
Document 6: Score = 0.7424
