# Tweb Journal Paper: ColBERT-PRF: Semantic Pseudo-Relevance Feedback for Dense Passage and Document Retrieval

This notebook demonstrates the experiments in our Tweb journal paper, including:

- Experiments measuring the informativeness of ColBERT-PRF's expansion embeddings;
- Experiments on efficient variants of ColBERT-PRF.


## Installation

Installing pyt_colbert also installs PyTerrier. You also need to have [FAISS installed](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md).

In [1]:
!nvidia-smi

Sun Sep 25 16:33:54 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.74       Driver Version: 470.74       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA TITAN RTX    Off  | 00000000:DB:00.0 Off |                  N/A |
| 41%   43C    P8    23W / 280W |    662MiB / 24220MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [5]:
import pyterrier as pt
pt.init()
from pyterrier.measures import *


PyTerrier 0.8.1 has loaded Terrier 5.6 (built by craigmacdonald on 2021-09-17 13:27)



## Setup

We use an existing index of the MSMARCO v1 passage corpus, previously built with pyt_colbert (which adds the tokenids file needed by the experiments below).

In [6]:
from pyterrier_colbert.ranking import ColBERTFactory

factory = ColBERTFactory(
    "/colbert_checkpoint_path/colbert.dnn",
    "/path/to/indices/colbert_passage/","index_name3",memtype='mem'
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing ColBERT: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ColBERT were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['linear.weight']
You should probably TRAI

[Sep 25, 16:34:06] #> Loading model checkpoint.
[Sep 25, 16:34:06] #> Loading checkpoint /nfs/xiao/GOOD_MODELS/colbert.dnn
[Sep 25, 16:34:07] #> checkpoint['epoch'] = 0
[Sep 25, 16:34:07] #> checkpoint['batch'] = 44500


## Baseline

This is the default ColBERT dense retrieval setting: a set-based ANN retrieval from the FAISS index, followed by exact scoring using the large ColBERT index.
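For reference, the exact-scoring stage ranks each candidate passage with ColBERT's late-interaction (MaxSim) score: for every query embedding, take the maximum dot product with any passage embedding, then sum over the query embeddings. The NumPy sketch below illustrates this on toy, L2-normalised matrices; `maxsim_score` is a hypothetical helper, not part of pyterrier_colbert.

```python
import numpy as np


def maxsim_score(query_embs: np.ndarray, doc_embs: np.ndarray) -> float:
    """Late-interaction score: sum over query embeddings of the maximum
    dot product with any document embedding.

    query_embs: (num_query_tokens, dim); doc_embs: (num_doc_tokens, dim).
    Rows are assumed L2-normalised, so a dot product equals cosine similarity.
    """
    sim = query_embs @ doc_embs.T        # (num_query_tokens, num_doc_tokens)
    return float(sim.max(axis=1).sum())  # best match per query token, summed


# Toy example with 2-dimensional unit vectors (hypothetical values).
q = np.array([[1.0, 0.0], [0.0, 1.0]])
d = np.array([[1.0, 0.0], [0.6, 0.8]])
print(maxsim_score(q, d))  # best matches 1.0 and 0.8
```

Each query embedding is matched independently, so one passage embedding can serve as the best match for several query embeddings.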

In [7]:
factory.faiss_index_on_gpu = True
e2e = factory.end_to_end()

[Sep 25, 16:34:12] #> Loading the FAISS index from /nfs/craigm/indices/colbert_passage/index_name3/ivfpq.faiss ..
[Sep 25, 16:34:26] #> Building the emb2pid mapping..
[Sep 25, 16:35:01] len(self.emb2pid) = 687989391



Loading reranking index, memtype=mem


Loading index shards to memory: 100%|██████████| 24/24 [02:33<00:00,  6.40s/shard]


In [8]:
import pandas as pd

dataset = pt.get_dataset("trec-deep-learning-passages")

topics2019 = dataset.get_topics('test-2019')
qrels2019 = dataset.get_qrels('test-2019')

topics2020 = dataset.get_topics('test-2020')
qrels2020 = dataset.get_qrels('test-2020')


# Experiments for ColBERT-PRF variants

In [10]:
fnt=factory.nn_term(df=True)

[Sep 25, 16:38:10] #> Building the emb2tid mapping..
687989391
Loading doclens


In [11]:
import torch
import numpy as np
num_docs=fnt.num_docs
num_all_tokens = len(fnt.emb2tid) 
idfdict = {}
ictfdict = {}
for tid in pt.tqdm(range(fnt.inference.query_tokenizer.tok.vocab_size)):
    df = fnt.getDF_by_id(tid)
    # add-one smoothed IDF score
    idfscore = np.log((1+num_docs)/(df+1))
    idfdict[tid] = idfscore
    # inverse collection term frequency (ICTF) score
    cf = fnt.getCTF_by_id(tid)
    ictfscore = np.log((num_all_tokens+1)/(cf+1))
    ictfdict[tid] = ictfscore

100%|██████████| 30522/30522 [00:00<00:00, 143230.96it/s]
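The cell above scores every WordPiece token id with add-one smoothed IDF, log((N+1)/(df+1)), where N is the number of passages, and with ICTF, log((T+1)/(cf+1)), where T is the total number of token embeddings in the index and cf is the token's collection frequency. The sketch below reproduces both formulas on a hypothetical three-passage toy corpus of token ids, to make the difference between document frequency and collection frequency concrete.

```python
from collections import Counter

import numpy as np

# Hypothetical toy corpus: each passage is a list of token ids.
passages = [[7, 7, 3], [7, 5], [3, 9, 9, 9]]

num_docs = len(passages)                           # N
num_all_tokens = sum(len(p) for p in passages)     # T
df = Counter(t for p in passages for t in set(p))  # document frequency
cf = Counter(t for p in passages for t in p)       # collection frequency


def idf(tid):
    # add-one smoothed IDF, same formula as the cell above
    return np.log((1 + num_docs) / (df[tid] + 1))


def ictf(tid):
    # inverse collection term frequency, same formula as the cell above
    return np.log((num_all_tokens + 1) / (cf[tid] + 1))


for tid in sorted(cf):
    print(tid, df[tid], cf[tid], round(idf(tid), 3), round(ictf(tid), 3))
```

Tokens 7 and 9 both occur three times in total, so they share the same ICTF, but token 9 occurs in only one passage and therefore receives the higher IDF.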


In [12]:
embs_to_analyse = 10_000_000 # number of token embeddings (from the start of the index) to analyse
id2meancos = {}
for tid in pt.tqdm(range(fnt.inference.query_tokenizer.tok.vocab_size)):
    # positions of this token among the first embs_to_analyse embeddings
    occurrences = torch.where(fnt.emb2tid[0:embs_to_analyse] == tid)
    if len(occurrences[0]) > 0:
        all101 = factory.rrm.part_mmap[0].mmap[occurrences]
        all101Mean = all101.mean(0)
        # mean cosine similarity between each occurrence and the token's mean embedding
        mean_cos = torch.nn.functional.cosine_similarity(all101Mean.unsqueeze(0).double(),
                                                         all101.double()).mean()
        id2meancos[tid] = mean_cos.item()
    else:
        id2meancos[tid] = 0

100%|██████████| 30522/30522 [47:20<00:00, 10.75it/s] 
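The cell above computes a MeanCos score per token id: gather the embeddings of that token's occurrences (here, among the first `embs_to_analyse` embeddings of the index), take their mean vector, and average the cosine similarity of each occurrence to that mean. A value close to 1 means the token's contextualised embeddings vary little across contexts. A minimal NumPy version on hypothetical random data (`mean_cos` is an illustrative helper, not part of pyterrier_colbert):

```python
import numpy as np


def mean_cos(token_embs: np.ndarray) -> float:
    """Average cosine similarity between each occurrence embedding of a
    token and the mean of all its occurrence embeddings."""
    if len(token_embs) == 0:
        return 0.0  # same fallback as the cell above for unseen tokens
    centre = token_embs.mean(axis=0)
    norms = np.linalg.norm(token_embs, axis=1) * np.linalg.norm(centre)
    return float(((token_embs @ centre) / norms).mean())


rng = np.random.default_rng(0)
base = rng.normal(size=128)
# Occurrences that stay close to one direction vs. occurrences that are scattered.
stable = base + 0.05 * rng.normal(size=(50, 128))
scattered = rng.normal(size=(50, 128))
print(mean_cos(stable), mean_cos(scattered))
```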


In [19]:
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
from sklearn.cluster import KMeans
from sklearn_extra.cluster import KMedoids

def get_nearest_tokens_for_emb(self, emb, k=10, low_tf=0):
    """
        Returns a dict mapping each decoded token to the number of times it occurs
        among the k nearest index embeddings of emb, most frequent first.
        Special tokens, [unused*] tokens and tokens with frequency <= low_tf are skipped.
    """
    scores, ids = self.faiss_index.faiss_index.search(np.array([emb]), k=k)
    id2freq = defaultdict(int)
    for id_set in ids:
        for id in id_set:
            id2freq[self.emb2tid[id].item()] += 1
    skips = set(self.inference.query_tokenizer.tok.special_tokens_map.values())
    rtr = {}
    for t, freq in sorted(id2freq.items(), key=lambda item: -1* item[1]):
        if freq <= low_tf:
            continue
        token = self.inference.query_tokenizer.tok.decode([t])
        if "[unused" in token or token in skips:
            continue
        rtr[token] = freq
    return rtr
        
def rmv_padding(prf_embs):
    """Drops all-zero (padding) rows from a [num_embs, dim] tensor of embeddings."""
    outputs =torch.empty([0,prf_embs.shape[1]])
    for emb in prf_embs:
        if emb.float().sum() == 0:
            continue
        else:
            outputs = torch.cat((outputs,emb.unsqueeze(0)))
    return outputs

class ColbertPRF_variants(pt.Transformer):
    def __init__(self, k, exp_terms, beta=1, r = 42, mean_cos_weight=False, idf_weight=False, ictf_weight=False,return_docs = False, fb_docs=10, kmeans=False,kmedoids=False, kmeansclosest=False,*args, **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k
        self.exp_terms = exp_terms
        self.beta = beta
        self.mean_cos_weight = mean_cos_weight
        self.idf_weight = idf_weight
        self.ictf_weight = ictf_weight
        self.return_docs = return_docs
        self.fb_docs = fb_docs
        self.r = r
        self.kmedoids = kmedoids
        self.kmeansclosest = kmeansclosest
        self.kmeans = kmeans
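        assert self.k > self.exp_terms, "exp_terms should be smaller than the number of clusters k"
        assert sum([kmeans, kmedoids, kmeansclosest]) == 1, "exactly one of kmeans, kmedoids, kmeansclosest must be True"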


    def KMeans_clustering(self,prf_embs):
        kmn =  KMeans(self.k, random_state=self.r)
        kmn.fit(prf_embs)
        
        emb_and_score = []
        for cluster in range(self.k):
            # take the centroid, needs to be the float32.
            centroid = np.float32( kmn.cluster_centers_[cluster] )
            tok2freq = get_nearest_tokens_for_emb(fnt, centroid)
            if len(tok2freq) == 0:
                continue
            most_likely_tok = max(tok2freq, key=tok2freq.get)
            tid = fnt.inference.query_tokenizer.tok.convert_tokens_to_ids(most_likely_tok)
            
            if self.mean_cos_weight:
                emb_and_score.append( (centroid, most_likely_tok, tid, id2meancos[tid])) 
            elif self.idf_weight:
                emb_and_score.append( (centroid, most_likely_tok, tid, idfdict[tid]) ) 
            elif self.ictf_weight:
                emb_and_score.append( (centroid, most_likely_tok, tid, ictfdict[tid]) )

        return emb_and_score
            

    def KMedoids_clustering(self,prf_embs,prf_toks):
        prf_embs = rmv_padding(prf_embs)
        kmedoids = KMedoids(n_clusters=self.k, random_state=self.r, init='k-medoids++',method='pam').fit(prf_embs)
        centroids = kmedoids.cluster_centers_
        idx_centroid = kmedoids.medoid_indices_
        
        emb_and_score = []
        for cluster in range(self.k):
            centroid = kmedoids.cluster_centers_[cluster]
            tid = prf_toks[idx_centroid[cluster]]
            token = fnt.inference.query_tokenizer.tok.decode([tid])
         
            if self.mean_cos_weight:
                emb_and_score.append( (centroid, token, tid, id2meancos[int(tid)]))
            elif self.idf_weight:
                emb_and_score.append( (centroid, token, tid, idfdict[int(tid)]) ) 
            elif self.ictf_weight:
                emb_and_score.append( (centroid, token, tid, ictfdict[int(tid)]) ) 
        return emb_and_score
    
    def KMeansClosest_clustering(self,prf_embs):
        prf_embs = rmv_padding(prf_embs)
        kmn =  KMeans(self.k, random_state=self.r)
        kmn.fit(prf_embs)
        emb_and_score = []
        D_Matrix = kmn.transform(prf_embs)
        for cluster in range(self.k):
            idx = np.argmin(D_Matrix[:,cluster])
            centroid = np.float32(prf_embs[idx])
            tid = fnt.emb2tid[idx]
            centroid = kmn.cluster_centers_[cluster]
            token = fnt.inference.query_tokenizer.tok.decode([tid])

            if self.mean_cos_weight:
                emb_and_score.append( (centroid, token, tid, id2meancos[int(tid)]))
            elif self.idf_weight:
                emb_and_score.append( (centroid, token, tid, idfdict[int(tid)]) ) 
            elif self.ictf_weight:
                emb_and_score.append( (centroid, token, tid, ictfdict[int(tid)]) )  
        return emb_and_score
            

    def transform_query(self, topic_and_res):
        topic_and_res = topic_and_res.sort_values('rank')
        if 'doc_embs' in topic_and_res.columns:
            prf_embs = torch.cat(topic_and_res.head(self.fb_docs).doc_embs.tolist())
        else:
            prf_embs = torch.cat([factory.rrm.get_embedding(docid) for docid in topic_and_res.head(self.fb_docs).docid.values])

        prf_toks = torch.cat([factory.nn_term().get_tokens_for_doc(docid) for docid in topic_and_res.head(self.fb_docs).docid.values])    


        if self.kmeans:
            emb_and_score = self.KMeans_clustering(prf_embs)
        elif self.kmedoids:
            emb_and_score = self.KMedoids_clustering(prf_embs,prf_toks)
        elif self.kmeansclosest:
            emb_and_score = self.KMeansClosest_clustering(prf_embs)
        
        
        sorted_by_second = sorted(emb_and_score, key=lambda tup: -tup[3])
        
        exp_toks=[]
        scores=[]
        exp_embds = []
        exp_tokens = []
        
        for i in range(min(self.exp_terms, len(sorted_by_second))):
            emb, tok, tid, score = sorted_by_second[i]
            exp_toks.append(tid)
            exp_tokens.append(tok)


            scores.append(score)
            exp_embds.append(emb)
        
        first_row = topic_and_res.iloc[0]
        
        newtoks = torch.cat([first_row.query_toks,torch.tensor(exp_toks)])
        newemb = torch.cat([
            first_row.query_embs, 
            torch.Tensor(exp_embds)])
        # apply weighting to the query embeddings
        if self.mean_cos_weight or self.idf_weight or self.ictf_weight:
            # we are using mean_cos weighting?
            weights = torch.cat([ 
                torch.ones(len(first_row.query_embs)),
                self.beta * torch.Tensor(scores)]
            )
        else:
            weights = torch.cat([ 
                torch.ones(len(first_row.query_embs)),
                torch.full((len(exp_embds),), float(self.beta))]
            )
        
        rtr = pd.DataFrame([
            [first_row.qid, 
             first_row.docno,
             first_row.query, 
             newemb, 
             newtoks,
             exp_tokens, 
             weights ]], columns=["qid","docno", "query", "query_embs","query_toks","expansion toks",  "query_weights"])
        return rtr
        
    def transform(self, topics_and_docs):
        # some validation of the input
        required = ["qid", "query", "docid", "docno", "query_embs", "query_toks"]
        for col in required:
            assert col in topics_and_docs.columns, "missing required column %s" % col
        rtr = []
        for qid, res in topics_and_docs.groupby("qid"):
            new_query_df = self.transform_query(res)     
            if self.return_docs:
                new_query_df = res[["qid", "docno", "docid"]].merge(new_query_df, on=["qid"])
                
                new_query_df = new_query_df.rename(columns={'docno_x':'docno'})
            rtr.append(new_query_df)
        return pd.concat(rtr)


# ColBERT-PRF Variants (measuring informativeness of expansion embeddings)
- We now compare the effectiveness of three measures of expansion-embedding informativeness: IDF, ICTF and MeanCos.
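In `ColbertPRF_variants.transform_query`, the selected expansion embeddings are appended to the query embeddings and a `query_weights` vector is built: weight 1 for each original query embedding and `beta` times the informativeness score (IDF, ICTF or MeanCos) for each expansion embedding. Assuming the downstream scorer applies these weights inside the MaxSim sum, the resulting score can be sketched as below; `weighted_maxsim`, `expansion_weights` and the toy inputs are hypothetical.

```python
import numpy as np


def expansion_weights(num_query_embs, scores, beta):
    # 1 for each original query embedding, beta * score for each expansion
    # embedding, mirroring the torch.cat([...]) in transform_query
    return np.concatenate([np.ones(num_query_embs), beta * np.asarray(scores, dtype=float)])


def weighted_maxsim(query_embs, doc_embs, weights):
    """Sum over query embeddings of weight * max dot product with the document."""
    sim = query_embs @ doc_embs.T
    return float((weights * sim.max(axis=1)).sum())


query = np.array([[1.0, 0.0]])      # one original query embedding
expansion = np.array([[0.0, 1.0]])  # one expansion embedding
doc = np.array([[0.6, 0.8]])        # one document embedding

w = expansion_weights(len(query), scores=[0.25], beta=2.0)  # weights [1.0, 0.5]
expanded = np.vstack([query, expansion])
print(w, weighted_maxsim(expanded, doc, w))  # 0.6 * 1.0 + 0.8 * 0.5
```

Raising `beta` (e.g. `beta=5` for MeanCos above) increases the influence of the expansion embeddings relative to the original query.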

In [14]:
prf_kmeans_idf = (
    factory.set_retrieve() >> factory.index_scorer(query_encoded=True, add_ranks=True, batch_size=10000)%10
    >> ColbertPRF_variants(k=24, exp_terms=10, idf_weight=True, kmeans = True, kmedoids=False,kmeansclosest = False, beta=1,fb_docs=3)
    >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
)

prf_kmeans_ictf = (
    factory.set_retrieve() >> factory.index_scorer(query_encoded=True, add_ranks=True, batch_size=10000)%10
    >> ColbertPRF_variants(k=24, exp_terms=10, ictf_weight=True,kmeans = True, kmedoids=False,kmeansclosest = False, beta=1,fb_docs=3)
    >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
)
prf_kmeans_mcos = (
    factory.set_retrieve() >> factory.index_scorer(query_encoded=True, add_ranks=True, batch_size=10000)%10
    >> ColbertPRF_variants(k=24, exp_terms=10, mean_cos_weight=True,kmeans = True, kmedoids=False,kmeansclosest = False, beta=5,fb_docs=3)
    >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
)

In [15]:
from pyterrier.measures import *
pt.Experiment(
    [
    prf_kmeans_idf,
    prf_kmeans_ictf,
    prf_kmeans_mcos,

    ],
    topics2019,
    qrels2019,
    batch_size=1, 
    verbose=True,
    filter_by_qrels=True,
    eval_metrics=[RR(rel=2), nDCG@10, nDCG@100,nDCG@1000, AP(rel=2)@100,AP(rel=2)@1000,R(rel=2)@100,R(rel=2)@1000],
    names=["prf_kmeans_idf (beta=1)","prf_kmeans_ictf (beta=1)","prf_kmeans_mcos (beta=5)"]
)

pt.Experiment: 100%|██████████| 129/129 [05:13<00:00,  2.43s/batches]


| name | RR(rel=2) | nDCG@10 | nDCG@100 | nDCG@1000 | AP(rel=2)@100 | AP(rel=2)@1000 | R(rel=2)@100 | R(rel=2)@1000 |
|---|---|---|---|---|---|---|---|---|
| prf_kmeans_idf (beta=1) | 0.885797 | 0.735153 | 0.690286 | 0.762042 | 0.481151 | 0.543161 | 0.670812 | 0.870633 |
| prf_kmeans_ictf (beta=1) | 0.872971 | 0.72323 | 0.672207 | 0.748912 | 0.466233 | 0.527005 | 0.654767 | 0.863237 |
| prf_kmeans_mcos (beta=5) | 0.864485 | 0.737531 | 0.691674 | 0.761707 | 0.483343 | 0.545149 | 0.669737 | 0.867126 |


These results correspond to Fig. 10(a) of our Tweb paper. The results in Fig. 10(b), (c) and (d) can be obtained with similar settings.

# Efficient ColBERT-PRF Variants (A: ColBERT-PRF with different clustering techniques)
Here we demonstrate the effect of different clustering techniques on the TREC DL 2019 test queries.
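The three clustering variants in `ColbertPRF_variants` differ in how each cluster becomes an expansion embedding and a token id. KMeans uses each centroid as the expansion embedding and maps it to a token through a FAISS nearest-neighbour lookup (`get_nearest_tokens_for_emb`). KMeansClosest keeps the KMeans centroid as the embedding but derives the token id from the index of the feedback embedding closest to that centroid, avoiding the FAISS lookup. KMedoids uses the medoid, an actual feedback embedding, together with its own token id. The sketch below shows how the closest feedback embedding is found from scikit-learn's `KMeans.transform`, which returns the distance of every point to every centroid; the embeddings and token ids are hypothetical random data.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(42)
prf_embs = rng.normal(size=(60, 8)).astype(np.float32)  # hypothetical feedback embeddings
prf_toks = rng.integers(1000, 2000, size=60)             # hypothetical token id per embedding

k = 4
kmn = KMeans(n_clusters=k, random_state=42, n_init=10).fit(prf_embs)
dist = kmn.transform(prf_embs)  # (num_embs, k): Euclidean distance to each centroid

for cluster in range(k):
    centroid = kmn.cluster_centers_[cluster]  # expansion embedding (KMeans, KMeansClosest)
    idx = int(np.argmin(dist[:, cluster]))    # feedback embedding nearest to this centroid
    closest_emb = prf_embs[idx]               # an actual embedding, similar in spirit to a medoid
    print(cluster, idx, int(prf_toks[idx]), float(np.linalg.norm(closest_emb - centroid)))
```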

In [20]:

import numpy as np

e2e = factory.end_to_end()

prf_kmedoids = (
    factory.set_retrieve() >> factory.index_scorer(query_encoded=True, add_ranks=True, batch_size=10000)%10
    >> ColbertPRF_variants(k=24, exp_terms=10, mean_cos_weight=False, idf_weight=True,kmedoids=True, beta=1,fb_docs=3)
    >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
)

prf_kmeans = (
    factory.set_retrieve() >> factory.index_scorer(query_encoded=True, add_ranks=True, batch_size=10000)%10
    >> ColbertPRF_variants(k=24, exp_terms=10, mean_cos_weight=False, idf_weight=True,kmeans = True, kmedoids=False,kmeansclosest = False, beta=1,fb_docs=3)
    >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
)


prf_kmeansclosest = (
    factory.set_retrieve() >> factory.index_scorer(query_encoded=True, add_ranks=True, batch_size=10000)%10
    >> ColbertPRF_variants(k=24, exp_terms=10, mean_cos_weight=False, idf_weight=True,kmeans = False, kmedoids=False,kmeansclosest = True, beta=1,fb_docs=3)
    >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
)



In [21]:
from pyterrier.measures import *
pt.Experiment(
    [
    e2e,
    prf_kmeans,
    prf_kmedoids,
    prf_kmeansclosest
    ],
    topics2019,
    qrels2019,
    batch_size=1, 
    verbose=True,
    filter_by_qrels=True,
    eval_metrics=[RR(rel=2), nDCG@10, nDCG@100,nDCG@1000, AP(rel=2)@100,AP(rel=2)@1000,R(rel=2)@100,R(rel=2)@1000],
    names=["e2e","prf-kmeans","prf-kmedoids","prf-kmeansclosest"]
)

pt.Experiment: 100%|██████████| 172/172 [04:54<00:00,  1.71s/batches]


| name | RR(rel=2) | nDCG@10 | nDCG@100 | nDCG@1000 | AP(rel=2)@100 | AP(rel=2)@1000 | R(rel=2)@100 | R(rel=2)@1000 |
|---|---|---|---|---|---|---|---|---|
| e2e | 0.852883 | 0.693407 | 0.602954 | 0.672184 | 0.386995 | 0.430988 | 0.578838 | 0.789166 |
| prf-kmeans | 0.885797 | 0.735153 | 0.690286 | 0.762042 | 0.481151 | 0.543161 | 0.670812 | 0.870633 |
| prf-kmedoids | 0.872332 | 0.719888 | 0.666308 | 0.749149 | 0.443443 | 0.507308 | 0.655673 | 0.868132 |
| prf-kmeansclosest | 0.849663 | 0.729587 | 0.660544 | 0.729889 | 0.449797 | 0.507054 | 0.650107 | 0.850491 |


These results correspond to Table 7 of our Tweb paper, specifically the ColBERT-PRF ranking scenario on the TREC 2019 query set. Results for the reranking scenario and for the TREC 2020 query set can be obtained with similar settings.



# Efficient ColBERT-PRF Variants (B: Approximate ANN ranking)
- We follow the Approximate ANN ranking technique proposed in [Macdonald21a]: Craig Macdonald and Nicola Tonellotto. On Approximate Nearest Neighbour Selection for Multi-Stage Dense Retrieval. In Proceedings of CIKM 2021. https://arxiv.org/abs/2108.11480

- Here, we study the effectiveness/efficiency tradeoff of Approximate ANN ranking combined with different clustering techniques.
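As a rough illustration of the idea in [Macdonald21a], the first stage can rank candidate documents using only the similarities already returned by the ANN search: for each query embedding, a document contributes the best similarity among its retrieved embeddings, and contributes nothing for query embeddings whose neighbours it does not appear in. Only the top-ranked documents (e.g. `factory.ann_retrieve_score() % 300` in the pipelines that follow) are then rescored exactly. This is a minimal sketch under that assumption; the input format (per query embedding, a list of `(docid, similarity)` pairs) and `approx_scores` are hypothetical.

```python
from collections import defaultdict


def approx_scores(ann_hits):
    """ann_hits[i] lists the (docid, similarity) pairs returned by the ANN search
    for query embedding i; several pairs may share a docid.

    A document's approximate score sums, over query embeddings, the best
    similarity among its retrieved embeddings; absent documents add 0."""
    totals = defaultdict(float)
    for hits in ann_hits:
        best = {}
        for docid, sim in hits:
            best[docid] = max(sim, best.get(docid, float("-inf")))
        for docid, sim in best.items():
            totals[docid] += sim
    return sorted(totals.items(), key=lambda kv: -kv[1])


# Two query embeddings, three documents (hypothetical values).
hits = [
    [("d1", 0.9), ("d2", 0.7), ("d1", 0.4)],
    [("d2", 0.8), ("d3", 0.6)],
]
print(approx_scores(hits))  # d2 (0.7 + 0.8) ranks above d1 (0.9) and d3 (0.6)
```

No embeddings are fetched from the large ColBERT index at this stage, which is where the efficiency gain comes from.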

In [22]:
prf_rank_approx_1stage = (
    factory.ann_retrieve_score() % 300
    >> factory.index_scorer(query_encoded=True)
    >> factory.fetch_index_encodings(ids=True)
    >> ColbertPRF_variants(k=24, exp_terms=10, mean_cos_weight=False, idf_weight=True, kmeans=True, beta=1, fb_docs=3)
    >> factory.set_retrieve(query_encoded=True)
    >> (factory.index_scorer(query_encoded=True, add_ranks=True, batch_size=5000) % 1000)
)

prf_rank_approx_13stage = (
    factory.ann_retrieve_score() % 300
    >> factory.index_scorer(query_encoded=True)
    >> factory.fetch_index_encodings(ids=True)
    >> ColbertPRF_variants(k=24, exp_terms=10, mean_cos_weight=False, idf_weight=True, kmeans=True, beta=1, fb_docs=3)
    >> factory.ann_retrieve_score(query_encoded=True) % 1000
    >> (factory.index_scorer(query_encoded=True, add_ranks=True, batch_size=5000) % 1000)
)

prf_rerank_approx = (
    factory.ann_retrieve_score() % 1000
    >> factory.index_scorer(query_encoded=True)
    >> factory.fetch_index_encodings(ids=True)
    >> ColbertPRF_variants(k=24, exp_terms=10, mean_cos_weight=False, idf_weight=True, kmeans=True, beta=1, fb_docs=3, return_docs=True)
    >> (factory.index_scorer(query_encoded=True, add_ranks=True, batch_size=5000) % 1000)
)


In [23]:
from pyterrier.measures import *
pt.Experiment(
    [
    prf_rank_approx_1stage,
    prf_rank_approx_13stage,
    prf_rerank_approx

    ],
    topics2019,
    qrels2019,
    batch_size=1, 
    verbose=True,
    filter_by_qrels=True,
    eval_metrics=[RR(rel=2), nDCG@10, nDCG@100,nDCG@1000, AP(rel=2)@100,AP(rel=2)@1000,R(rel=2)@100,R(rel=2)@1000],
    names=["ann_kmeans_ranker.1stage","ann_kmeans_ranker.13stage","ann_kmeans_reranker"]
)

pt.Experiment: 100%|██████████| 129/129 [03:20<00:00,  1.55s/batches]


| name | RR(rel=2) | nDCG@10 | nDCG@100 | nDCG@1000 | AP(rel=2)@100 | AP(rel=2)@1000 | R(rel=2)@100 | R(rel=2)@1000 |
|---|---|---|---|---|---|---|---|---|
| ann_kmeans_ranker.1stage | 0.86488 | 0.731379 | 0.690439 | 0.763278 | 0.483502 | 0.546565 | 0.666271 | 0.869528 |
| ann_kmeans_ranker.13stage | 0.864631 | 0.731379 | 0.679103 | 0.718573 | 0.474658 | 0.519824 | 0.652961 | 0.80441 |
| ann_kmeans_reranker | 0.885826 | 0.732971 | 0.623566 | 0.637119 | 0.423349 | 0.456489 | 0.585278 | 0.695251 |


These results correspond to Table 7 of our Tweb paper, specifically ColBERT-PRF with KMeans clustering combined with the Approximate Scoring technique on the TREC 2019 query set.

Experiments on the TREC 2020 query set can be obtained with a similar setting.

Likewise, experiments that combine the Approximate Scoring technique with KMeansClosest or KMedoids clustering can be obtained with a similar setting.
