This repository uses `uv` for dependency management. Install it first, then restore all dependencies with the following command:

```bash
uv sync
```

In [6]:
import torch

fill_model = "answerdotai/ModernBERT-base"
pylate_model = "joe32140/ColModernBERT-base-msmarco-en-bge"
device = "cuda" if torch.cuda.is_available() else "cpu"

### Sample 1 - Basic usage of the base model

The base ModernBERT model can perform "fill-mask" tasks, predicting the token hidden behind `[MASK]`. Here's how you can use it:

In [None]:
import torch
from transformers import pipeline
from pprint import pprint

pipe = pipeline(
    "fill-mask",
    model=fill_model,
    torch_dtype=torch.bfloat16,
    device=device,
)

input_text = "Łódź is a city in [MASK]."
results = pipe(input_text)
pprint(results)
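
For a single `[MASK]`, the fill-mask pipeline returns a list of candidate dictionaries, each with `score`, `token`, `token_str`, and `sequence` keys. The following is a minimal sketch of how that list could be condensed into a short ranking. The helper name `summarize_predictions` is hypothetical, and the sample values are made-up placeholders in the pipeline's output shape, not real model output:

```python
def summarize_predictions(results, top_n=3):
    """Return (token, score) pairs for the top_n highest-scoring candidates."""
    ranked = sorted(results, key=lambda r: r["score"], reverse=True)
    # token_str may carry a leading space, so strip it for display
    return [(r["token_str"].strip(), round(r["score"], 3)) for r in ranked[:top_n]]


# Made-up placeholder data (assumption), NOT real model output.
sample_results = [
    {"score": 0.10, "token": 2, "token_str": " Europe", "sequence": "Łódź is a city in Europe."},
    {"score": 0.80, "token": 1, "token_str": " Poland", "sequence": "Łódź is a city in Poland."},
]

for token, score in summarize_predictions(sample_results):
    print(f"{token}: {score}")
```

In the notebook, passing the real `results` from the cell above in place of `sample_results` should work the same way.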

### Sample 2 - Using a PyLate model for multi-vector retrieval

The code below initializes the PyLate model and creates an index for retrieval:

In [None]:
from pylate import indexes, models

# Step 1: Load the ColBERT model
model = models.ColBERT(
    model_name_or_path=pylate_model,
    device=device,
)

# Step 2: Initialize the Voyager index
index = indexes.Voyager(
    index_folder="pylate-index",
    index_name="index",
    override=True,  # This overwrites the existing index if any
)

# Step 3: Encode the documents
documents_ids = ["1", "2", "3"]
documents = ["Paris", "Łódź", "Washington"]

documents_embeddings = model.encode(
    documents,
    batch_size=32,
    is_query=False,  # False marks these inputs as documents, not queries
    show_progress_bar=True,
)

# Step 4: Add document embeddings to the index by providing embeddings and corresponding ids
index.add_documents(
    documents_ids=documents_ids,
    documents_embeddings=documents_embeddings,
)
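
"Multi-vector" here means that a ColBERT-style model produces one embedding per token rather than a single vector per document. Relevance is then typically computed with late interaction (MaxSim): each query token is matched to its most similar document token, and those maxima are summed. The sketch below illustrates the idea in plain Python with toy 2-D vectors. The function `maxsim_score` and the vectors are hypothetical illustrations, not PyLate's API or real embeddings:

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_vecs, doc_vecs):
    """Sum, over query tokens, of the best similarity to any document token."""
    return sum(max(dot(q, d) for d in doc_vecs) for q in query_vecs)


# Toy token embeddings (assumption: made up for illustration only).
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # has a close match for each query token
doc_b = [[0.5, 0.5]]                          # only a partial match for both

print(maxsim_score(query, doc_a))  # 2.0
print(maxsim_score(query, doc_b))  # 1.0
```

Because each query token picks its own best match, a document scores well when it covers all query tokens, which is why the index stores every token embedding instead of one pooled vector.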

Here's the code for performing retrieval:

In [None]:
from pylate import retrieve

# Step 1: Initialize the ColBERT retriever
retriever = retrieve.ColBERT(index=index)

# Step 2: Encode the queries
queries_embeddings = model.encode(
    ["What is in Poland?"],
    batch_size=32,
    is_query=True,
    show_progress_bar=True,
)

# Step 3: Retrieve top-k documents
scores = retriever.retrieve(
    queries_embeddings=queries_embeddings,
    k=10,  # Retrieve the top 10 matches for each query
)

print(scores)
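
The printed result only contains document ids and scores, so it usually helps to join it back to the original texts. The sketch below assumes that `retrieve` returns one list of hits per query and that each hit is a dictionary with `id` and `score` keys; verify this against your installed PyLate version. The helper `attach_texts` is hypothetical, and the sample hits are made-up placeholders, not real retrieval output:

```python
def attach_texts(hits_per_query, ids, texts):
    """Replace each hit's id with the matching document text."""
    id_to_text = dict(zip(ids, texts))
    return [
        [(id_to_text.get(hit["id"]), hit["score"]) for hit in hits]
        for hits in hits_per_query
    ]


documents_ids = ["1", "2", "3"]
documents = ["Paris", "Łódź", "Washington"]

# Made-up placeholder hits for one query (assumption), NOT real scores.
sample_scores = [[{"id": "2", "score": 3.0}, {"id": "1", "score": 1.5}]]

for text, score in attach_texts(sample_scores, documents_ids, documents)[0]:
    print(text, score)
```

In the notebook, the real `scores` variable would take the place of `sample_scores`.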