# ColBERTKP Re-ranking Example

This notebook demonstrates how to re-rank BM25 results with ColBERT and ColBERTKP models on the TREC DL 2019 passage ranking dataset.

## 1. Setup and Imports

In [3]:
# Set JAVA_HOME before importing PyTerrier
import os
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-21-openjdk-amd64'

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Now initialize PyTerrier
import pyterrier as pt
if not pt.started():
    pt.init()

print(f"✓ JAVA_HOME set to: {os.environ['JAVA_HOME']}")
print(f"✓ PyTerrier initialized successfully")

✓ JAVA_HOME set to: /usr/lib/jvm/java-21-openjdk-amd64
✓ PyTerrier initialized successfully


In [4]:
import json
from pyterrier.measures import *
from pyterrier_pisa import PisaIndex

# Import ColBERT after PyTerrier initialization
from pyterrier_colbert.ranking import ColBERTFactory

print("✓ All libraries imported successfully")

✓ All libraries imported successfully


## 2. Load Dataset and BM25 Index

We'll use the TREC DL 2019 dataset with the MS MARCO passage collection. First-stage BM25 retrieval comes from a prebuilt PISA index loaded through `pyterrier_pisa`.

In [5]:
# Load the dataset
dataset = pt.get_dataset("irds:msmarco-passage/trec-dl-2019/judged")

# Get topics (queries) and qrels (relevance judgments)
topics = dataset.get_topics()
qrels = dataset.get_qrels()

print(f"Number of topics: {len(topics)}")
print(f"Number of qrels: {len(qrels)}")
print("\nExample topics:")
topics.head(3)

Number of topics: 43
Number of qrels: 9260

Example topics:


| | qid | query |
|---|---|---|
| 0 | 156493 | do goldfish grow |
| 1 | 1110199 | what is wifi vs bluetooth |
| 2 | 1063750 | why did the us volunterilay enter ww1 |


In [6]:
index = PisaIndex.from_dataset('msmarco_passage')
bm25 = index.bm25(num_results=1000)

print("✓ BM25 retriever loaded successfully")

✓ BM25 retriever loaded successfully


## 3. Initialize ColBERT Models

We'll load the two models used in the experiments:
- **ColBERT**: Standard ColBERT model
- **ColBERTKP**: ColBERT trained with keyphrases

In [7]:
# Note: For re-ranking, we don't need index directories since we're using text_scorer(),
# which scores directly from text without requiring pre-built indices

import torch

# ColBERT model
colbert_checkpoint = "../resources/models/colbert-cosine-200k.dnn"

# Remove the stale 'bert.embeddings.position_ids' entry from the saved state dict so the
# checkpoint loads without a key mismatch. weights_only=False allows unpickling the full
# checkpoint dict, so only use it on files you trust.
# Caution: this overwrites the checkpoint file in place.
checkpoint_dict = torch.load(colbert_checkpoint, map_location='cpu', weights_only=False)
if 'model_state_dict' in checkpoint_dict and 'bert.embeddings.position_ids' in checkpoint_dict['model_state_dict']:
    del checkpoint_dict['model_state_dict']['bert.embeddings.position_ids']
    torch.save(checkpoint_dict, colbert_checkpoint)

colbert_factory = ColBERTFactory(
    colbert_checkpoint,
    index_root=None,  # Not needed for re-ranking
    index_name=None   # Not needed for re-ranking
)

print("✓ ColBERT model loaded")

# ColBERTKP model
colbertkp_checkpoint = "../resources/models/colbertkp-cosine-25k.dnn"

# Drop the stale position_ids entry from this checkpoint too (the file is overwritten in place)
checkpoint_dict = torch.load(colbertkp_checkpoint, map_location='cpu', weights_only=False)
if 'model_state_dict' in checkpoint_dict and 'bert.embeddings.position_ids' in checkpoint_dict['model_state_dict']:
    del checkpoint_dict['model_state_dict']['bert.embeddings.position_ids']
    torch.save(checkpoint_dict, colbertkp_checkpoint)

colbertkp_factory = ColBERTFactory(
    colbertkp_checkpoint,
    index_root=None,
    index_name=None
)

print("✓ ColBERTKP model loaded")

Some weights of ColBERT were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[nov 12, 10:38:38] #> Loading model checkpoint.
[nov 12, 10:38:38] #> Loading checkpoint ../resources/models/colbert-cosine-200k.dnn
[nov 12, 10:38:39] #> checkpoint['epoch'] = 0
[nov 12, 10:38:39] #> checkpoint['batch'] = 200000
✓ ColBERT model loaded


Some weights of ColBERT were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['linear.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[nov 12, 10:38:40] #> Loading model checkpoint.
[nov 12, 10:38:40] #> Loading checkpoint ../resources/models/colbertkp-cosine-25k.dnn
[nov 12, 10:38:41] #> checkpoint['epoch'] = 0
[nov 12, 10:38:41] #> checkpoint['batch'] = 25000
✓ ColBERTKP model loaded
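
The clean-up performed on both checkpoints in the cell above could be factored into one small helper. The sketch below is a minimal illustration, not part of the notebook: `drop_position_ids` is a hypothetical name, and a plain dictionary stands in for the object returned by `torch.load`, so the example runs without PyTorch.

```python
def drop_position_ids(checkpoint, key="bert.embeddings.position_ids"):
    """Remove `key` from checkpoint['model_state_dict'] if it is present.

    Returns True when the checkpoint was modified, False otherwise.
    """
    state = checkpoint.get("model_state_dict", {})
    if key in state:
        del state[key]
        return True
    return False


# Toy stand-in for a loaded checkpoint (assumption: same top-level layout
# as the real files, i.e. a 'model_state_dict' entry holding the weights).
toy = {
    "epoch": 0,
    "batch": 200000,
    "model_state_dict": {
        "bert.embeddings.position_ids": [0, 1, 2],
        "linear.weight": [[0.1, 0.2]],
    },
}

changed = drop_position_ids(toy)
print(changed)                           # True
print(sorted(toy["model_state_dict"]))   # ['linear.weight']
```

With a real checkpoint, the return value could be used to call `torch.save` only when something actually changed, ideally to a new path so the original file is left untouched.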


## 4. Build Re-ranking Pipelines

Each pipeline consists of:
1. BM25 retrieval (top 1000 documents)
2. Text lookup with `pt.text.get_text` to attach each passage's content from the dataset
3. ColBERT re-ranking using `text_scorer()`
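
To make the third step more concrete, the sketch below illustrates the late-interaction (MaxSim) scoring that ColBERT is known for: each query token embedding is matched to its most similar document token embedding, and the maxima are summed. It is a toy illustration under stated assumptions, not the `pyterrier_colbert` implementation: the 2-dimensional vectors stand in for BERT token embeddings, cosine similarity is assumed because the checkpoint names suggest it, and `maxsim_score` and `rerank` are hypothetical helper names.

```python
import math


def normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def maxsim_score(query_vecs, doc_vecs):
    """Sum, over query tokens, of the best cosine similarity to any document token."""
    q = [normalize(v) for v in query_vecs]
    d = [normalize(v) for v in doc_vecs]
    if not d:
        return 0.0
    return sum(max(sum(a * b for a, b in zip(qv, dv)) for dv in d) for qv in q)


def rerank(query_vecs, candidates):
    """Score (docno, doc_vecs) candidates and return (docno, score) pairs, best first."""
    scored = [(docno, maxsim_score(query_vecs, vecs)) for docno, vecs in candidates]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Toy data: two query tokens, two candidate passages from a first stage.
query = [[1.0, 0.0], [0.0, 1.0]]
candidates = [
    ("d2", [[1.0, 1.0]]),              # one token, halfway between both query tokens
    ("d1", [[1.0, 0.0], [0.0, 1.0]]),  # matches both query tokens exactly
]

ranking = rerank(query, candidates)
print(ranking[0][0])  # d1
```

Because every candidate is scored independently from its text, this kind of re-ranking needs no pre-built dense index, only the passages returned by the first stage.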

In [8]:
# Create re-ranking pipelines
colbert_pipeline = bm25 >> pt.text.get_text(dataset, "text") >> colbert_factory.text_scorer()
colbertkp_pipeline = bm25 >> pt.text.get_text(dataset, "text") >> colbertkp_factory.text_scorer()

print("✓ Re-ranking pipelines created")

✓ Re-ranking pipelines created


## 5. Run Evaluation (Questions)

We'll evaluate all three systems (the BM25 baseline and the two ColBERT re-rankers) on the original question queries, using standard metrics for TREC DL 2019:
- **AP@1000**: Average Precision at rank 1000 (rel≥2)
- **nDCG@10**: Normalized Discounted Cumulative Gain at rank 10
- **RR@10**: Reciprocal Rank at rank 10 (rel≥2)
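
For readers who want to see what these metrics compute for a single query, here is a minimal pure-Python sketch. It is an illustration, not the code that `pt.Experiment` runs: the function names and toy judgments are hypothetical, and the nDCG variant uses linear gains with a log2 discount, which is one common convention that should be checked against the evaluation library before comparing numbers.

```python
import math


def rr_at_k(ranked, qrels, k=10, min_rel=2):
    """Reciprocal rank of the first document with relevance >= min_rel in the top k."""
    for rank, docno in enumerate(ranked[:k], start=1):
        if qrels.get(docno, 0) >= min_rel:
            return 1.0 / rank
    return 0.0


def ap_at_k(ranked, qrels, k=1000, min_rel=2):
    """Average precision over the top k, with judgments binarized at min_rel."""
    n_relevant = sum(1 for rel in qrels.values() if rel >= min_rel)
    if n_relevant == 0:
        return 0.0
    hits, total = 0, 0.0
    for rank, docno in enumerate(ranked[:k], start=1):
        if qrels.get(docno, 0) >= min_rel:
            hits += 1
            total += hits / rank
    return total / n_relevant


def ndcg_at_k(ranked, qrels, k=10):
    """nDCG over the top k using graded judgments as linear gains."""
    def dcg(gains):
        return sum(g / math.log2(i + 2) for i, g in enumerate(gains))

    gains = [qrels.get(docno, 0) for docno in ranked[:k]]
    ideal = sorted(qrels.values(), reverse=True)[:k]
    idcg = dcg(ideal)
    return dcg(gains) / idcg if idcg > 0 else 0.0


# Toy judgments on the 0-3 scale used by TREC DL, and one system ranking.
toy_qrels = {"d1": 3, "d2": 2, "d3": 1, "d4": 0}
toy_ranking = ["d3", "d1", "d4", "d2"]

rr = rr_at_k(toy_ranking, toy_qrels)      # first rel>=2 document is at rank 2 -> 0.5
ap = ap_at_k(toy_ranking, toy_qrels)      # (1/2 + 2/4) / 2 -> 0.5
ndcg = ndcg_at_k(toy_ranking, toy_qrels)
print(rr, ap, round(ndcg, 4))
```

The reported scores are these per-query values averaged over all topics. Note that `d3` (relevance 1) counts as non-relevant for AP and RR because of the rel≥2 threshold, but still contributes gain to nDCG.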

In [9]:
# Define evaluation metrics for TREC DL 2019
metrics = [AP(rel=2)@1000, nDCG@10, RR(rel=2)@10]

# Run experiment comparing the BM25 baseline with the two re-ranking pipelines
pt.Experiment(
    [bm25, colbert_pipeline, colbertkp_pipeline],
    topics,
    qrels,
    eval_metrics=metrics,
    names=["BM25", "ColBERT", "ColBERTKP"],
    batch_size=1024,
    round=4,
)

| | name | AP(rel=2)@1000 | nDCG@10 | RR(rel=2)@10 |
|---|---|---|---|---|
| 0 | BM25 | 0.3031 | 0.4989 | 0.678 |
| 1 | ColBERT | 0.4595 | 0.7039 | 0.8353 |
| 2 | ColBERTKP | 0.4609 | 0.7106 | 0.8469 |


## 6. Test with Keyphrases

The same pipelines can also be evaluated with keyphrases automatically extracted from the queries by a Mistral model.

In [10]:
# Load keyphrases extracted by Mistral
keyphrases_path = "../resources/data/trec_2019_test_mistral_kps.json"

with open(keyphrases_path) as f:
    kps_json = json.load(f)

# Create modified topics with keyphrases instead of original queries
topics_kp = topics.copy()
for i, row in topics_kp.iterrows():
    if row["qid"] in kps_json:
        topics_kp.at[i, "query"] = kps_json[row["qid"]]

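# Only the last expression of a cell is displayed automatically,
# so show both frames explicitly
from IPython.display import display

print("Original query examples:")
display(topics[["qid", "query"]].head(3))
print("\nKeyphrase query examples:")
display(topics_kp[["qid", "query"]].head(3))

Original query examples:

| | qid | query |
|---|---|---|
| 0 | 156493 | do goldfish grow |
| 1 | 1110199 | what is wifi vs bluetooth |
| 2 | 1063750 | why did the us volunterilay enter ww1 |

Keyphrase query examples:

| | qid | query |
|---|---|---|
| 0 | 156493 | goldfish growth |
| 1 | 1110199 | wifi vs bluetooth |
| 2 | 1063750 | us volunteer entry ww1 reason |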
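
The substitution logic in the cell above can be sketched without pandas. The example assumes, based on how `kps_json` is indexed in that cell, that the keyphrase file is a flat JSON object mapping each qid to one keyphrase string. The helper name `substitute_keyphrases` is hypothetical, and the third query is deliberately left out of the toy mapping to show the fallback to the original text.

```python
import json

# Toy stand-in for the keyphrase file (assumed layout: {qid: keyphrase}).
kps_raw = '{"156493": "goldfish growth", "1110199": "wifi vs bluetooth"}'
kps = json.loads(kps_raw)

topic_rows = [
    {"qid": "156493", "query": "do goldfish grow"},
    {"qid": "1110199", "query": "what is wifi vs bluetooth"},
    {"qid": "1063750", "query": "why did the us volunterilay enter ww1"},
]


def substitute_keyphrases(rows, keyphrases):
    """Return copies of rows with the query replaced by its keyphrase when one exists.

    JSON object keys are always strings, so the qid is cast to str before lookup.
    """
    return [
        {**row, "query": keyphrases.get(str(row["qid"]), row["query"])}
        for row in rows
    ]


rows_kp = substitute_keyphrases(topic_rows, kps)
for row in rows_kp:
    print(row["qid"], "->", row["query"])
```

Casting the qid to `str` matters: if the topics held integer qids, the membership test would silently fail and every query would keep its original text. On a DataFrame, an equivalent one-liner would be `topics["qid"].map(kps_json).fillna(topics["query"])`.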


In [11]:
# Run experiment with keyphrase queries
pt.Experiment(
    [bm25, colbert_pipeline, colbertkp_pipeline],
    topics_kp,
    qrels,
    eval_metrics=metrics,
    names=["BM25", "ColBERT", "ColBERTKP"],
    batch_size=1024,
    round=4
)

| | name | AP(rel=2)@1000 | nDCG@10 | RR(rel=2)@10 |
|---|---|---|---|---|
| 0 | BM25 | 0.2859 | 0.4751 | 0.612 |
| 1 | ColBERT | 0.4499 | 0.6849 | 0.8516 |
| 2 | ColBERTKP | 0.4724 | 0.7059 | 0.8876 |
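
To compare the two result tables, the short sketch below computes, for each system, the change in every metric when switching from question queries to keyphrase queries. The numbers are copied from the outputs above. The helper name `metric_deltas` is hypothetical, and the differences come from a single run over 43 topics with no significance test, so they should be read as indicative only.

```python
# Scores copied from the two experiment outputs above.
QUESTIONS = {
    "BM25":      {"AP@1000": 0.3031, "nDCG@10": 0.4989, "RR@10": 0.678},
    "ColBERT":   {"AP@1000": 0.4595, "nDCG@10": 0.7039, "RR@10": 0.8353},
    "ColBERTKP": {"AP@1000": 0.4609, "nDCG@10": 0.7106, "RR@10": 0.8469},
}
KEYPHRASES = {
    "BM25":      {"AP@1000": 0.2859, "nDCG@10": 0.4751, "RR@10": 0.612},
    "ColBERT":   {"AP@1000": 0.4499, "nDCG@10": 0.6849, "RR@10": 0.8516},
    "ColBERTKP": {"AP@1000": 0.4724, "nDCG@10": 0.7059, "RR@10": 0.8876},
}


def metric_deltas(before, after):
    """Per-system, per-metric difference (after - before), rounded to 4 decimals."""
    return {
        name: {m: round(after[name][m] - before[name][m], 4) for m in scores}
        for name, scores in before.items()
    }


deltas = metric_deltas(QUESTIONS, KEYPHRASES)
for name, row in deltas.items():
    print(name, row)
```

In this run, BM25 drops on all three metrics with keyphrase queries, ColBERT drops on AP and nDCG@10 but gains on RR@10, and ColBERTKP gains on AP and RR@10 while losing the least on nDCG@10.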
