In [2]:
# !pip install transformers==4.10.0

In [1]:
!nvidia-smi

Wed Jan 18 23:11:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.74       Driver Version: 470.74       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA TITAN RTX    Off  | 00000000:1A:00.0 Off |                  N/A |
| 41%   30C    P8    14W / 280W |      3MiB / 24220MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

## PyTerrier Setup

Ensure that FAISS-GPU is installed and working, and set up PyTerrier.

In [2]:
import faiss
# Fail fast with an actionable message: a later cell sets
# `factory.faiss_index_on_gpu = True`, so FAISS must be able to see a GPU.
assert faiss.get_num_gpus() > 0, "FAISS cannot see any GPU; install faiss-gpu and check the CUDA setup"

In [5]:
import pyterrier as pt
# Guarded so that re-running this cell (or any later cell that initialises
# PyTerrier again) does not try to start it a second time.
if not pt.started():
    pt.init()

PyTerrier 0.9.1 has loaded Terrier 5.7 (built by craigm on 2022-11-10 18:30) and terrier-helper 0.0.7



In [7]:
from pyterrier_colbert.ranking import ColBERTFactory
# Build the ColBERT factory used by the rest of the notebook.
# Positional arguments: the model checkpoint to load, the directory holding the
# index, and the index name inside that directory (the log output of the next
# cells reports loading ".../colbert_passage/index_name3/ivfpq.faiss").
# memtype='mem' matches the "Loading reranking index, memtype=mem" log line.
# NOTE(review): absolute /nfs paths are site-specific — adjust before re-running.
factory = ColBERTFactory(
    "/nfs/xiao/GOOD_MODELS/colbert.dnn",
    "/nfs/craigm/indices/colbert_passage/","index_name3",memtype='mem'
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing ColBERT: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ColBERT were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['linear.weight']
You should probably TRAI

[Jan 18, 23:17:38] #> Loading model checkpoint.
[Jan 18, 23:17:38] #> Loading checkpoint /nfs/xiao/GOOD_MODELS/colbert.dnn
[Jan 18, 23:17:40] #> checkpoint['epoch'] = 0
[Jan 18, 23:17:40] #> checkpoint['batch'] = 44500


In [17]:
# Set the factory's GPU flag for the FAISS index before any retriever is built
# (presumably this moves the ANN index to the GPU — confirm in pyterrier_colbert).
factory.faiss_index_on_gpu = True
# End-to-end ColBERT pipeline; used later both as the "colbert_e2e" baseline
# and as the first stage of every CWPRF pipeline.
e2e = factory.end_to_end()

[Jan 18, 23:21:37] #> Loading the FAISS index from /nfs/craigm/indices/colbert_passage/index_name3/ivfpq.faiss ..
[Jan 18, 23:22:00] #> Building the emb2pid mapping..
[Jan 18, 23:22:46] len(self.emb2pid) = 687989391
Loading reranking index, memtype=mem


Loading index shards to memory: 100%|██████████| 24/24 [02:58<00:00,  7.42s/shard]


In [3]:
import pyterrier as pt
# PyTerrier is already initialised by an earlier cell when the notebook is run
# top to bottom, so only start it if that has not happened yet.
if not pt.started():
    pt.init()

# TREC Deep Learning passage-ranking dataset handle.
dataset = pt.get_dataset("trec-deep-learning-passages")
# Downloadable ColBERT checkpoint, consumed by the ColBERTFactory call in the next cell.
checkpoint="http://www.dcs.gla.ac.uk/~craigm/ecir2021-tutorial/colbert_model_checkpoint.zip"

PyTerrier 0.8.0 has loaded Terrier 5.6 (built by craigmacdonald on 2021-09-17 13:27)



In [4]:
from pyterrier_colbert.ranking import ColBERTFactory
#update this to the location of your ColBERT index for MSMARCO passage ranking.
index=("/nfs/indices/colbert_passage","index_name3")

# NOTE(review): this rebinds the notebook-global `factory` that an earlier cell
# already created from a local checkpoint with memtype='mem'. Here the
# checkpoint is the URL from the previous cell, the index root differs, and no
# memtype is passed. Run top to bottom, later cells (`fnt`, `e2e`, the CWPRF
# classes) therefore use THIS factory — confirm which of the two is intended.
factory = ColBERTFactory(checkpoint, *index)

In [8]:
import pandas as pd

# Fetch the dataset handle once and reuse it for both query sets instead of
# calling pt.get_dataset() four times.
dl_passages = pt.get_dataset("trec-deep-learning-passages")

qrels2019 = dl_passages.get_qrels('test-2019')
topics2019 = dl_passages.get_topics('test-2019')
# Keep only the topics that have at least one judgement with label > 0.
topics2019 = topics2019.merge(qrels2019[qrels2019["label"] > 0][["qid"]].drop_duplicates())

topics2020 = dl_passages.get_topics('test-2020')
qrels2020 = dl_passages.get_qrels('test-2020')
# Same filter for the 2020 query set.
topics2020 = topics2020.merge(qrels2020[qrels2020["label"] > 0][["qid"]].drop_duplicates())



# SPRF-Inference

In [29]:
import torch 
from torch import nn
if torch.cuda.is_available():
    from torch.cuda.amp import autocast
from transformers import PreTrainedModel, BertModel, BertTokenizer, BertConfig
fnt=factory.nn_term(df=True)
class CWPRFEncoder(PreTrainedModel):
    """BERT encoder followed by a single-output linear head over every token.

    ``forward`` passes its keyword arguments straight to the wrapped
    ``BertModel`` and returns the ReLU of the linear projection of the last
    hidden state, i.e. one non-negative value per token position (the trailing
    dimension has size 1).
    """
    config_class = BertConfig
    base_model_prefix = 'encoder'
    load_tf_weights = None

    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.config = config
        self.bert = BertModel(config)
        # hidden_size -> 1: a scalar score per token
        self.tok_proj = torch.nn.Linear(config.hidden_size, 1)

    def _init_weights(self, module):
        """Initialise one sub-module (required by the inherited from_pretrained)."""
        is_linear = isinstance(module, torch.nn.Linear)
        if is_linear or isinstance(module, torch.nn.Embedding):
            # Plain normal init; the TF implementation uses truncated_normal,
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if is_linear and module.bias is not None:
            module.bias.data.zero_()

    def init_weights(self):
        """Initialise the BERT body with its own routine and the head with ``_init_weights``."""
        self.bert.init_weights()
        self.tok_proj.apply(self._init_weights)

    def forward(self, **kargs):
        """Return ReLU(linear(last_hidden_state)) for the given BertModel inputs."""
        hidden_states = self.bert(**kargs).last_hidden_state
        return torch.relu(self.tok_proj(hidden_states))

In [24]:
import os
# if not os.path.exists("stopword-list.txt"):
#     !wget "https://raw.githubusercontent.com/terrier-org/terrier-core/5.x/modules/core/src/main/resources/stopword-list.txt"

# Read the stopword list, one entry per line (see the commented-out download
# command above for where the file comes from).
with open("stopword-list.txt") as f:
    stops = [line.strip() for line in f]

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Vocabulary ids of the stopwords. Entries the tokenizer does not know map to
# its unknown-token id (100 for bert-base-uncased) and are dropped, so that
# filtering by `stop_ids` never removes unknown tokens wholesale.
stop_ids = [tid for tid in tokenizer.convert_tokens_to_ids(stops) if tid != tokenizer.unk_token_id]
def obtain_tok_w_emb_both(batch_token_ids,batch_embs,batch_weights,variant='AAAT',mode='MAX',stopW=False):
    """Rank candidate expansion tokens by weight, one row per distinct token id.

    Parameters
    ----------
    batch_token_ids, batch_embs, batch_weights
        Token ids, token embeddings and token weights; all three must have the
        same ``len``. With ``variant='AAAT'`` each is one CPU tensor (``.numpy()``
        is called on them). With ``variant='OAAT'`` each is a sequence holding
        one tensor per feedback passage.
    variant : {'AAAT', 'OAAT'}
        Selects which of the two input layouts above is expected.
    mode : {'MAX', 'AVG'}
        How repeated token ids are collapsed: 'MAX' keeps the occurrence with
        the largest weight (and its embedding); 'AVG' replaces them by the mean
        weight and the mean embedding.
    stopW : bool
        When true, rows whose id is in the notebook-level ``stop_ids`` list are
        dropped.

    Returns
    -------
    pandas.DataFrame
        Columns ``tid``, ``tok_weight``, ``tok_emb``, sorted by ``tok_weight``
        in descending order with a fresh 0..n-1 index. Ids 101, 102, 1 and 2
        are always removed.

    Raises
    ------
    ValueError
        If ``variant`` or ``mode`` is not one of the supported values
        (previously this surfaced as an unbound-local error at the end).
    """
    if variant not in ('AAAT', 'OAAT'):
        raise ValueError(f"variant must be 'AAAT' or 'OAAT', got {variant!r}")
    if mode not in ('MAX', 'AVG'):
        raise ValueError(f"mode must be 'MAX' or 'AVG', got {mode!r}")
    assert len(batch_embs)==len(batch_token_ids)==len(batch_weights)

    columns = ['tid', 'tok_weight', 'tok_emb']
    if variant == 'OAAT':
        per_passage = [
            pd.DataFrame({'tid': toks.detach().tolist(),
                          'tok_weight': weights.detach().tolist(),
                          'tok_emb': embs.tolist()})
            for toks, embs, weights in zip(batch_token_ids, batch_embs, batch_weights)
        ]
        # pd.concat([]) raises, so an empty passage list gets an empty frame
        df = pd.concat(per_passage) if per_passage else pd.DataFrame(columns=columns)
    else:
        df = pd.DataFrame({'tid': list(batch_token_ids.numpy()),
                           'tok_weight': list(batch_weights.detach().numpy()),
                           'tok_emb': list(batch_embs.numpy())})

    if stopW:
        df = df[~df['tid'].isin(stop_ids)]

    # presumably [CLS]/[SEP] plus ColBERT's query/document marker tokens — TODO confirm
    special_ids = [101,102,1,2]
    # Reset the index so that labels are unique (OAAT concatenation repeats
    # them) and label-based selection below is unambiguous.
    df = df[~df['tid'].isin(special_ids)].reset_index(drop=True)

    singles = df[~df['tid'].duplicated(keep=False)]          # ids seen exactly once
    dup_tids = df.loc[df['tid'].duplicated(), 'tid'].unique()  # ids seen more than once

    if mode == 'MAX':
        # For every repeated id keep the row carrying its largest weight.
        best = [df.loc[df['tid'] == tid, 'tok_weight'].idxmax() for tid in dup_tids]
        rtr = df.loc[list(singles.index) + best]
    elif len(dup_tids) == 0:
        rtr = singles
    else:
        mean_weight, mean_emb = [], []
        for tid in dup_tids:
            group = df[df['tid'] == tid]
            mean_weight.append(group['tok_weight'].mean())
            mean_emb.append(torch.mean(torch.stack([torch.Tensor(emb) for emb in group['tok_emb']], dim=0), 0))
        rtr_mean = pd.DataFrame({'tid': dup_tids, 'tok_weight': mean_weight, 'tok_emb': mean_emb})
        # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
        rtr = pd.concat([singles, rtr_mean])

    return rtr.sort_values(by='tok_weight', ascending=False).reset_index(drop=True)


    


In [25]:

import pandas as pd, torch
from pyterrier_colbert.ranking import ColbertPRF
import sklearn, numpy as np
import math

class CWPRF_AAAT(ColbertPRF):
    """ColBERT PRF transformer that scores all feedback passages in one encoder pass.

    The query tokens and the tokens of the top ``fb_docs`` passages are
    concatenated into a single sequence, scored by a ``CWPRFEncoder``, and the
    ``fb_embs`` highest-weighted passage tokens are appended to the query as
    expansion embeddings.

    Extra keyword arguments (everything else is forwarded to ``ColbertPRF``):
        path:  optional checkpoint file whose ``'model_state_dict'`` entry is
               loaded into the encoder.
        stopW: drop stopword tokens from the expansion candidates.
        mode:  'MAX' or 'AVG' (see ``obtain_tok_w_emb_both``); ``None`` selects 'MAX'.

    Relies on the notebook globals ``factory``, ``fnt``, ``CWPRFEncoder`` and
    ``obtain_tok_w_emb_both``.
    """
    def __init__(self,*args,path=None,stopW=False,mode=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = CWPRFEncoder.from_pretrained("castorini/unicoil-msmarco-passage")
        self.tokenizer = BertTokenizer.from_pretrained("castorini/unicoil-msmarco-passage")
        if path is not None:
            # NOTE: torch.load unpickles the file — only load trusted checkpoints.
            # The encoder is never moved off the CPU in this class, so map there.
            checkpoint = torch.load(path, map_location='cpu')
            self.model.load_state_dict(checkpoint['model_state_dict'])
        self.max_length = 512  # hardcode for now
        self.stopW = stopW
        # The helper only accepts 'MAX' or 'AVG'; fall back to its own default.
        self.mode = 'MAX' if mode is None else mode

    def _expanded_query(self, first_row, top):
        """Build the one-row result frame from the original query and the selected expansion rows.

        ``top`` is the head of the weight-sorted candidate table. With no
        candidates the query is returned unexpanded, with unit weights.
        """
        newemb = first_row.query_embs
        toks = first_row.query_toks
        # the weights column defines the importance of each query embedding
        weights = torch.ones(len(first_row.query_embs))
        if len(top) > 0:
            raw_weights = np.asarray(top.tok_weight.tolist())
            # rows are sorted by weight, so the first one carries the maximum
            max_weight = raw_weights[0]
            # normalise by the maximum; skip when it is 0 to avoid NaN weights
            sprf_weights = raw_weights / max_weight if max_weight > 0 else raw_weights
            exp_embds = [torch.Tensor(item).float() for item in top.tok_emb.tolist()]
            # concatenate the new embeddings to the existing query embeddings
            newemb = torch.cat([newemb, torch.stack(exp_embds)])
            weights = torch.cat([weights, self.beta * torch.Tensor(sprf_weights)])
            toks = torch.cat([toks, torch.IntTensor(top.tid.tolist())])
        # generate the revised query dataframe row
        return pd.DataFrame(
            [[first_row.qid, first_row.docno, first_row.query, newemb, toks, weights]],
            columns=["qid","docno", "query", "query_embs", "query_toks", "query_weights"])

    def transform_query(self, topic_and_res : pd.DataFrame) -> pd.DataFrame:
        """Expand one query given its ranked results (one row per retrieved passage)."""
        topic_and_res = topic_and_res.sort_values('rank')
        first_row = topic_and_res.iloc[0]
        Q_toks = torch.cat((topic_and_res.head(1)["query_toks"].values).tolist(),dim=0)
        Q_embs = torch.cat((topic_and_res.head(1)["query_embs"].values).tolist(),dim=0)
        # number of leading positions that belong to the query, not to feedback text
        n_query = Q_toks.shape[0]

        # get the toks and embeddings in the feedback passages
        prf_embs = [Q_embs]
        prf_toks = [Q_toks]
        for docid in topic_and_res.head(self.fb_docs).docid.values:
            doc_embs = factory.rrm.get_embedding(docid)
            doc_toks = fnt.get_tokens_for_doc(docid)
            # we need to crop out the unused embeddings
            prf_embs.append(doc_embs[0:doc_toks.shape[0], :])
            prf_toks.append(doc_toks)
        prf_toks = torch.cat(prf_toks, dim=0)[0:self.max_length]
        prf_embs = torch.cat(prf_embs, dim=0)[0:self.max_length, :]

        # inference only: no autograd graph is needed
        with torch.no_grad():
            outputs = self.model(input_ids=prf_toks.unsqueeze(0), attention_mask=torch.ones_like(prf_toks).unsqueeze(0))
        outputs_weights = outputs.squeeze(dim=2).squeeze(dim=0)

        # only the feedback positions (after the query prefix) are expansion candidates
        tok_and_weights_embs_df = obtain_tok_w_emb_both(
            prf_toks[n_query:], prf_embs[n_query:], outputs_weights[n_query:],
            variant = 'AAAT', mode=self.mode, stopW = self.stopW)

        return self._expanded_query(first_row, tok_and_weights_embs_df.head(self.fb_embs))
        


In [26]:
import pandas as pd, torch
from pyterrier_colbert.ranking import ColbertPRF
import sklearn, numpy as np
import math

class CWPRF_OAAT(ColbertPRF):
    """ColBERT PRF transformer that scores each feedback passage in its own encoder pass.

    For each of the top ``fb_docs`` passages the query tokens followed by the
    passage tokens are scored by a ``CWPRFEncoder``; the per-passage results are
    pooled and the ``fb_embs`` highest-weighted tokens are appended to the query
    as expansion embeddings.

    Extra keyword arguments (everything else is forwarded to ``ColbertPRF``):
        path:  optional checkpoint file whose ``'model_state_dict'`` entry is
               loaded into the encoder.
        stopW: drop stopword tokens from the expansion candidates.
        mode:  'MAX' or 'AVG' (see ``obtain_tok_w_emb_both``); ``None`` selects 'MAX'.

    Relies on the notebook globals ``factory``, ``fnt``, ``CWPRFEncoder`` and
    ``obtain_tok_w_emb_both``.
    """
    def __init__(self,*args,path=None,stopW=False,mode=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = CWPRFEncoder.from_pretrained("castorini/unicoil-msmarco-passage")
        self.tokenizer = BertTokenizer.from_pretrained("castorini/unicoil-msmarco-passage")
        if path is not None:
            # NOTE: torch.load unpickles the file — only load trusted checkpoints.
            # The encoder is never moved off the CPU in this class, so map there.
            checkpoint = torch.load(path, map_location='cpu')
            self.model.load_state_dict(checkpoint['model_state_dict'])
        self.max_length = 512  # hardcode for now
        self.stopW = stopW
        # The helper only accepts 'MAX' or 'AVG'; fall back to its own default.
        self.mode = 'MAX' if mode is None else mode

    def _expanded_query(self, first_row, top):
        """Build the one-row result frame from the original query and the selected expansion rows.

        ``top`` is the head of the weight-sorted candidate table. With no
        candidates the query is returned unexpanded, with unit weights.
        """
        newemb = first_row.query_embs
        toks = first_row.query_toks
        # the weights column defines the importance of each query embedding
        weights = torch.ones(len(first_row.query_embs))
        if len(top) > 0:
            raw_weights = np.asarray(top.tok_weight.tolist())
            # rows are sorted by weight, so the first one carries the maximum
            max_weight = raw_weights[0]
            # normalise by the maximum; skip when it is 0 to avoid NaN weights
            sprf_weights = raw_weights / max_weight if max_weight > 0 else raw_weights
            exp_embds = [torch.Tensor(item).float() for item in top.tok_emb.tolist()]
            # concatenate the new embeddings to the existing query embeddings
            newemb = torch.cat([newemb, torch.stack(exp_embds)])
            weights = torch.cat([weights, self.beta * torch.Tensor(sprf_weights)])
            toks = torch.cat([toks, torch.IntTensor(top.tid.tolist())])
        # generate the revised query dataframe row
        return pd.DataFrame(
            [[first_row.qid, first_row.docno, first_row.query, newemb, toks, weights]],
            columns=["qid","docno", "query", "query_embs", "query_toks", "query_weights"])

    def transform_query(self, topic_and_res : pd.DataFrame) -> pd.DataFrame:
        """Expand one query given its ranked results (one row per retrieved passage)."""
        topic_and_res = topic_and_res.sort_values('rank')
        first_row = topic_and_res.iloc[0]
        Q_toks = torch.cat((topic_and_res.head(1)["query_toks"].values).tolist(),dim=0)
        # number of leading positions that belong to the query, not to feedback text
        n_query = Q_toks.shape[0]
        # room left for passage tokens once the query prefix is accounted for
        doc_budget = self.max_length - n_query

        # get the toks and embeddings in the feedback passages
        prf_embs = []
        prf_toks = []
        outputs_seperate_weight = []
        for docid in topic_and_res.head(self.fb_docs).docid.values:
            doc_embs = factory.rrm.get_embedding(docid)
            doc_toks = fnt.get_tokens_for_doc(docid)
            # Truncate BEFORE storing so that tokens, embeddings and predicted
            # weights always have the same length for this passage.
            doc_toks = doc_toks[0:doc_budget]
            # we need to crop out the unused embeddings
            doc_embs = doc_embs[0:doc_toks.shape[0], :]

            prf_embs.append(doc_embs)
            prf_toks.append(doc_toks)

            doc_toks_withQ = torch.cat([Q_toks,doc_toks],dim=0)
            # inference only: no autograd graph is needed
            with torch.no_grad():
                outputs_seperate = self.model(input_ids=doc_toks_withQ.unsqueeze(0), attention_mask=torch.ones_like(doc_toks_withQ).unsqueeze(0))
            outputs_seperate = outputs_seperate.squeeze(dim=2).squeeze(dim=0)
            # keep only the weights of the passage positions
            outputs_seperate_weight.append(outputs_seperate[n_query:])

        # obtain the sprf predicted weights for each prf token
        # then, we rank and select the exp_embds according to this
        tok_and_weights_embs_df = obtain_tok_w_emb_both(
            prf_toks, prf_embs, outputs_seperate_weight,
            variant = 'OAAT', mode=self.mode, stopW = self.stopW)

        return self._expanded_query(first_row, tok_and_weights_embs_df.head(self.fb_embs))
        

# Validation on TREC 2019
- We now conduct validation experiments on the TREC 2019 query set, following the evaluation measures used by the [TREC 2019 Deep Learning track](https://arxiv.org/abs/2003.07820).

In [37]:
# checkpoint_path = "/path/to/checkpoint/sprf_OAAT/checkpoints/"
# Directory holding the AAAT training checkpoints that are compared below.
# NOTE(review): site-specific absolute path — adjust before re-running.
checkpoint_path  = "/nfs/sean/workspace_xiao/SPRF_seed/AAAT_bqe_ibn_pret_r800_b24/psg/train_AAAT.py/SPRF_AAAT/checkpoints/"

# One pipeline per checkpoint file SPRF1000.dnn ... SPRF10000.dnn.
pipes_A2A=[]
names=[]
for i in range(1,11):
    path = checkpoint_path+f"SPRF{i}000.dnn"
    # e2e results cut with `% 10` (PyTerrier rank cutoff — see PyTerrier operator docs),
    # then query expansion, then retrieval and scoring with the encoded expanded query.
    prf_AAAT = (e2e%10
             >>CWPRF_AAAT(factory,k=24,fb_embs=10,fb_docs=3,beta=5,return_docs=False,path=path,mode = 'MAX',stopW=True)
             >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
             )

    pipes_A2A.append(prf_AAAT)
    names.append(f"CWPRF_AAAT_{i}k")
    
from pyterrier.measures import *
# Compare the plain end-to-end baseline against every checkpoint on TREC 2019.
res_validation_2019 = pt.Experiment(
    [e2e]+pipes_A2A,
    topics2019,
    qrels2019,
    eval_metrics=[AP(rel=2)@1000, nDCG@10,R(rel=2)@1000],
#     baseline=0,
    names=["colbert_e2e"]+names, verbose=True, batch_size=10,highlight="bold"
)

res_validation_2019
    

100%|██████████| 30522/30522 [00:00<00:00, 274750.93it/s]
100%|██████████| 30522/30522 [00:00<00:00, 269079.37it/s]
100%|██████████| 30522/30522 [00:00<00:00, 259337.95it/s]
100%|██████████| 30522/30522 [00:00<00:00, 266309.72it/s]
100%|██████████| 30522/30522 [00:00<00:00, 267886.80it/s]
100%|██████████| 30522/30522 [00:00<00:00, 269173.29it/s]
100%|██████████| 30522/30522 [00:00<00:00, 268107.28it/s]
100%|██████████| 30522/30522 [00:00<00:00, 274141.98it/s]
100%|██████████| 30522/30522 [00:00<00:00, 275679.24it/s]
100%|██████████| 30522/30522 [00:00<00:00, 266648.09it/s]
pt.Experiment: 100%|██████████| 55/55 [14:42<00:00, 16.05s/batches]


Unnamed: 0,name,AP(rel=2)@1000,nDCG@10,R(rel=2)@1000
0,colbert_e2e,0.430988,0.693407,0.789166
1,CWPRF_AAAT_1k,0.52738,0.730404,0.862985
2,CWPRF_AAAT_2k,0.531881,0.744428,0.859627
3,CWPRF_AAAT_3k,0.527689,0.73856,0.85626
4,CWPRF_AAAT_4k,0.523911,0.730351,0.857645
5,CWPRF_AAAT_5k,0.525032,0.72487,0.858765
6,CWPRF_AAAT_6k,0.520622,0.727121,0.858419
7,CWPRF_AAAT_7k,0.517241,0.727596,0.854553
8,CWPRF_AAAT_8k,0.524057,0.736907,0.85979
9,CWPRF_AAAT_9k,0.522978,0.740247,0.861349


# Main results on both TREC 2019 and TREC 2020

- Following the evaluation measures used by the TREC Deep Learning track ([TREC 2019 Deep Learning track overview](https://arxiv.org/abs/2003.07820)), we now report the main results on the TREC 2019 and TREC 2020 query sets.

In [32]:
# Rebuilds the notebook-global end-to-end pipeline from the current `factory`.
e2e = factory.end_to_end()

# AAAT pipeline using the SPRF2000.dnn checkpoint.
path_AAAT = "/nfs/sean/workspace_xiao/SPRF_seed/AAAT_bqe_ibn_pret_r800_b24/psg/train_AAAT.py/SPRF_AAAT/checkpoints/SPRF2000.dnn"
prf_AAAT = (e2e%10
         >>CWPRF_AAAT(factory,k=24,fb_embs=10,fb_docs=3,beta=5,return_docs=False,path=path_AAAT,mode = 'MAX',stopW=True)
         >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
         )

# OAAT pipeline using the SPRF7000.dnn checkpoint; same settings otherwise.
path_OAAT = "/nfs/sean/workspace_xiao/SPRF_seed/OAAT_bqe_ibn_pret_r9090_b12/psg/train_OAAT.py/SPRF_AAAT/checkpoints/SPRF7000.dnn"
prf_OAAT = (e2e%10
             >> CWPRF_OAAT(factory,k=24,fb_embs=10,fb_docs=3,beta=5,return_docs=False,path=path_OAAT,mode = 'MAX',stopW=True)
             >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
             )

from pyterrier.measures import *
# Evaluate both variants on the TREC 2019 topics and qrels.
res_main_dl19 = pt.Experiment(
    [
     prf_AAAT, prf_OAAT],
    topics2019,
    qrels2019,
    eval_metrics=[AP(rel=2)@1000, nDCG@10 ,R(rel=2)@1000],
#     baseline=0,
    names=["AAAT.dl19","OAAT.dl19"], verbose=True, batch_size=10,highlight="bold"
)

res_main_dl19

100%|██████████| 30522/30522 [00:00<00:00, 280513.59it/s]
100%|██████████| 30522/30522 [00:00<00:00, 250692.33it/s]
pt.Experiment: 100%|██████████| 10/10 [03:14<00:00, 19.48s/batches]


Unnamed: 0,name,AP(rel=2)@1000,nDCG@10,R(rel=2)@1000
0,AAAT.dl19,0.531881,0.744428,0.859627
1,OAAT.dl19,0.525219,0.724404,0.872197


In [39]:
# NOTE(review): everything down to the pt.Experiment call repeats the pipeline
# construction of the TREC 2019 cell verbatim; only topics, qrels and the run
# names differ below.
e2e = factory.end_to_end()

# AAAT pipeline using the SPRF2000.dnn checkpoint.
path_AAAT = "/nfs/sean/workspace_xiao/SPRF_seed/AAAT_bqe_ibn_pret_r800_b24/psg/train_AAAT.py/SPRF_AAAT/checkpoints/SPRF2000.dnn"
prf_AAAT = (e2e%10
         >>CWPRF_AAAT(factory,k=24,fb_embs=10,fb_docs=3,beta=5,return_docs=False,path=path_AAAT,mode = 'MAX',stopW=True)
         >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
         )

# OAAT pipeline using the SPRF7000.dnn checkpoint; same settings otherwise.
path_OAAT = "/nfs/sean/workspace_xiao/SPRF_seed/OAAT_bqe_ibn_pret_r9090_b12/psg/train_OAAT.py/SPRF_AAAT/checkpoints/SPRF7000.dnn"
prf_OAAT = (e2e%10
             >> CWPRF_OAAT(factory,k=24,fb_embs=10,fb_docs=3,beta=5,return_docs=False,path=path_OAAT,mode = 'MAX',stopW=True)
             >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
             )

from pyterrier.measures import *
# Evaluate both variants on the TREC 2020 topics and qrels.
res_main_dl20 = pt.Experiment(
    [
     prf_AAAT, prf_OAAT],
    topics2020,
    qrels2020,
    eval_metrics=[AP(rel=2)@1000, nDCG@10 ,R(rel=2)@1000],
#     baseline=0,
    names=["AAAT.dl20","OAAT.dl20"], verbose=True, batch_size=10,highlight="bold"
)

res_main_dl20

100%|██████████| 30522/30522 [00:00<00:00, 274661.33it/s]
100%|██████████| 30522/30522 [00:00<00:00, 259573.00it/s]
pt.Experiment: 100%|██████████| 12/12 [03:53<00:00, 19.43s/batches]


Unnamed: 0,name,AP(rel=2)@1000,nDCG@10,R(rel=2)@1000
0,AAAT.dl20,0.513644,0.724582,0.878334
1,OAAT.dl20,0.504861,0.720389,0.878262


# Hyper-parameter Search
- We now tune the following hyper-parameters:
   - $f_p$, the number of feedback documents;
   - $\beta$, which controls the importance of the expansion embeddings;
   - $f_e$, the number of expansion embeddings used by SPRF.

### Nbr. Exp_terms $f_e$

In [40]:
from pyterrier.measures import *
def prfRanker_exp_terms( perquery=False, exp_terms_values=range(1,5)):
    """Evaluate the AAAT pipeline on TREC 2019 for several numbers of expansion embeddings.

    Uses the notebook globals ``e2e``, ``factory``, ``path_AAAT``,
    ``topics2019`` and ``qrels2019``.

    Parameters
    ----------
    perquery : bool
        Forwarded to ``pt.Experiment``.
    exp_terms_values : iterable of int
        Values passed as ``fb_embs``; defaults to 1..4.

    Returns
    -------
    pandas.DataFrame
        The per-setting experiment tables stacked under one running index.
    """
    topics = topics2019
    qrels = qrels2019
    dfs = []
    for exp_terms in exp_terms_values:
        prf_AAAT = (e2e%10
         >>CWPRF_AAAT(factory,k=24,fb_embs=exp_terms,fb_docs=3,beta=5,return_docs=False,path=path_AAAT,mode = 'MAX',stopW=True)
         >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
         )
        print("exp_terms:",exp_terms)
        df = pt.Experiment(
            [prf_AAAT],
            topics,
            qrels,
            # RR(rel=2) was listed twice; one entry yields the same result columns
            eval_metrics=[RR(rel=2), R(rel=2)@100, R(rel=2)@1000, nDCG@10, nDCG@100, AP(rel=2), "mrt"],
            # "%1f" rendered the integer setting as e.g. "1.000000"
            names=[f"AAAT {exp_terms} exp_terms"],
            perquery=perquery
        )
        print(df)
        dfs.append(df)
    # ignore_index: otherwise every stacked row keeps index 0
    return pd.concat(dfs, ignore_index=True)

In [41]:
res_nbr_exp=prfRanker_exp_terms()
res_nbr_exp

100%|██████████| 30522/30522 [00:00<00:00, 290055.37it/s]


exp_terms: 1


100%|██████████| 30522/30522 [00:00<00:00, 265512.71it/s]

                      name  RR(rel=2)  R(rel=2)@100  R(rel=2)@1000   nDCG@10  \
0  AAAT 1.000000 exp_terms   0.868364      0.604444       0.794932  0.707223   

   nDCG@100  AP(rel=2)          mrt  
0  0.632842    0.45618  1484.569138  





exp_terms: 2


100%|██████████| 30522/30522 [00:00<00:00, 158791.56it/s]

                      name  RR(rel=2)  R(rel=2)@100  R(rel=2)@1000   nDCG@10  \
0  AAAT 2.000000 exp_terms   0.865249      0.616204       0.809533  0.716412   

   nDCG@100  AP(rel=2)          mrt  
0  0.645126   0.473732  1536.781663  





exp_terms: 3


100%|██████████| 30522/30522 [00:00<00:00, 273805.42it/s]

                      name  RR(rel=2)  R(rel=2)@100  R(rel=2)@1000   nDCG@10  \
0  AAAT 3.000000 exp_terms    0.85614      0.632069       0.824032  0.721957   

   nDCG@100  AP(rel=2)          mrt  
0   0.66234   0.488571  1572.394161  





exp_terms: 4
                      name  RR(rel=2)  R(rel=2)@100  R(rel=2)@1000   nDCG@10  \
0  AAAT 4.000000 exp_terms   0.855383      0.632219       0.841125  0.725121   

   nDCG@100  AP(rel=2)         mrt  
0  0.662369   0.500486  1630.33407  


Unnamed: 0,name,RR(rel=2),R(rel=2)@100,R(rel=2)@1000,nDCG@10,nDCG@100,AP(rel=2),mrt
0,AAAT 1.000000 exp_terms,0.868364,0.604444,0.794932,0.707223,0.632842,0.45618,1484.569138
0,AAAT 2.000000 exp_terms,0.865249,0.616204,0.809533,0.716412,0.645126,0.473732,1536.781663
0,AAAT 3.000000 exp_terms,0.85614,0.632069,0.824032,0.721957,0.66234,0.488571,1572.394161
0,AAAT 4.000000 exp_terms,0.855383,0.632219,0.841125,0.725121,0.662369,0.500486,1630.33407


### Parameter $\beta$ beta

In [48]:
from pyterrier.measures import *
def prfRanker_beta( perquery=False, beta_values=range(0,11)):
    """Evaluate the AAAT pipeline on TREC 2019 for several expansion weights ``beta``.

    Uses the notebook globals ``e2e``, ``factory``, ``path_AAAT``,
    ``topics2019`` and ``qrels2019``.

    Parameters
    ----------
    perquery : bool
        Forwarded to ``pt.Experiment``.
    beta_values : iterable of numbers
        Values passed as ``beta``; defaults to 0..10.

    Returns
    -------
    pandas.DataFrame
        The per-setting experiment tables stacked under one running index.
    """
    topics = topics2019
    qrels = qrels2019
    dfs = []
    for beta in beta_values:
        prf_AAAT = (e2e%10
         >>CWPRF_AAAT(factory,k=24,fb_embs=10,fb_docs=3,beta=beta,return_docs=False,path=path_AAAT,mode = 'MAX',stopW=True)
         >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
         )
        print("beta:",beta)
        df = pt.Experiment(
            [prf_AAAT],
            topics,
            qrels,
            # RR(rel=2) was listed twice; one entry yields the same result columns
            eval_metrics=[RR(rel=2), R(rel=2)@100, R(rel=2)@1000, nDCG@10, nDCG@100, AP(rel=2), "mrt"],
            # "%1f" rendered the integer setting as e.g. "1.000000"
            names=[f"AAAT {beta} beta"],
            perquery=perquery
        )
        dfs.append(df)
    # ignore_index: otherwise every stacked row keeps index 0
    return pd.concat(dfs, ignore_index=True)

In [49]:
res_beta=prfRanker_beta()
res_beta

100%|██████████| 30522/30522 [00:00<00:00, 264093.53it/s]


beta: 0


100%|██████████| 30522/30522 [00:00<00:00, 279874.57it/s]


beta: 1


100%|██████████| 30522/30522 [00:00<00:00, 268588.78it/s]


beta: 2


100%|██████████| 30522/30522 [00:00<00:00, 273906.19it/s]


beta: 3


100%|██████████| 30522/30522 [00:00<00:00, 280842.82it/s]


beta: 4


100%|██████████| 30522/30522 [00:00<00:00, 269261.04it/s]


beta: 5


100%|██████████| 30522/30522 [00:00<00:00, 274987.00it/s]


beta: 6


100%|██████████| 30522/30522 [00:00<00:00, 278150.02it/s]


beta: 7


100%|██████████| 30522/30522 [00:00<00:00, 263898.63it/s]


beta: 8


100%|██████████| 30522/30522 [00:00<00:00, 271947.09it/s]


beta: 9


100%|██████████| 30522/30522 [00:00<00:00, 255333.39it/s]


beta: 10


Unnamed: 0,name,RR(rel=2),R(rel=2)@100,R(rel=2)@1000,nDCG@10,nDCG@100,AP(rel=2),mrt
0,AAAT 0.000000 beta,0.852883,0.586183,0.828534,0.693407,0.607981,0.443277,1938.104107
0,AAAT 1.000000 beta,0.873,0.615146,0.847557,0.727186,0.644476,0.488411,1868.802222
0,AAAT 2.000000 beta,0.880726,0.635667,0.857582,0.738134,0.661775,0.507114,1847.642882
0,AAAT 3.000000 beta,0.876833,0.638737,0.865188,0.742072,0.671292,0.520922,1870.105101
0,AAAT 4.000000 beta,0.876826,0.64423,0.86072,0.744311,0.678603,0.52932,1804.732567
0,AAAT 5.000000 beta,0.867128,0.650331,0.859627,0.744428,0.682528,0.533567,1897.757083
0,AAAT 6.000000 beta,0.864021,0.650224,0.85861,0.743791,0.683054,0.53518,1807.732631
0,AAAT 7.000000 beta,0.86324,0.649727,0.856016,0.739653,0.683411,0.536643,1828.004621
0,AAAT 8.000000 beta,0.862269,0.650317,0.855054,0.739414,0.684277,0.537505,1852.063124
0,AAAT 9.000000 beta,0.862267,0.648913,0.852675,0.740044,0.683743,0.537173,1866.029632


In [52]:
from pyterrier.measures import *
def prfRanker_fbdocs( perquery=False, fb_docs_values=range(1,5)):
    """Evaluate the AAAT pipeline on TREC 2019 for several numbers of feedback documents.

    Uses the notebook globals ``e2e``, ``factory``, ``path_AAAT``,
    ``topics2019`` and ``qrels2019``.

    Parameters
    ----------
    perquery : bool
        Forwarded to ``pt.Experiment``.
    fb_docs_values : iterable of int
        Values passed as ``fb_docs``; defaults to 1..4.

    Returns
    -------
    pandas.DataFrame
        The per-setting experiment tables stacked under one running index.
    """
    topics = topics2019
    qrels = qrels2019
    dfs = []
    for fb_doc in fb_docs_values:
        prf_AAAT = (e2e%10
         >>CWPRF_AAAT(factory,k=24,fb_embs=10,fb_docs=fb_doc,beta=5,return_docs=False,path=path_AAAT,mode = 'MAX',stopW=True)
         >> factory.set_retrieve(query_encoded=True)>>factory.index_scorer(query_encoded=True,add_ranks=True)
         )
        print("fb_docs:",fb_doc)
        df = pt.Experiment(
            [prf_AAAT],
            topics,
            qrels,
            # RR(rel=2) was listed twice; one entry yields the same result columns
            eval_metrics=[RR(rel=2), R(rel=2)@100, R(rel=2)@1000, nDCG@10, nDCG@100, AP(rel=2), "mrt"],
            # "%1f" rendered the integer setting as e.g. "1.000000"
            names=[f"AAAT {fb_doc} fb_doc"],
            perquery=perquery
        )
        print(df)
        dfs.append(df)
    # ignore_index: otherwise every stacked row keeps index 0
    return pd.concat(dfs, ignore_index=True)

In [53]:
res_fb_docs=prfRanker_fbdocs()
res_fb_docs

100%|██████████| 30522/30522 [00:00<00:00, 273221.64it/s]


fb_docs: 1


100%|██████████| 30522/30522 [00:00<00:00, 259649.86it/s]

                   name  RR(rel=2)  R(rel=2)@100  R(rel=2)@1000   nDCG@10  \
0  AAAT 1.000000 fb_doc    0.86056      0.641845       0.863299  0.724838   

   nDCG@100  AP(rel=2)          mrt  
0  0.668962    0.51078  1747.003051  





fb_docs: 2


100%|██████████| 30522/30522 [00:00<00:00, 277200.22it/s]

                   name  RR(rel=2)  R(rel=2)@100  R(rel=2)@1000   nDCG@10  \
0  AAAT 2.000000 fb_doc   0.833655      0.644076        0.84943  0.732905   

   nDCG@100  AP(rel=2)          mrt  
0  0.674183   0.526402  1773.806532  





fb_docs: 3


100%|██████████| 30522/30522 [00:00<00:00, 271432.18it/s]

                   name  RR(rel=2)  R(rel=2)@100  R(rel=2)@1000   nDCG@10  \
0  AAAT 3.000000 fb_doc   0.867128      0.650331       0.859627  0.744428   

   nDCG@100  AP(rel=2)          mrt  
0  0.682528   0.533567  1819.796863  





fb_docs: 4
                   name  RR(rel=2)  R(rel=2)@100  R(rel=2)@1000   nDCG@10  \
0  AAAT 4.000000 fb_doc   0.873547      0.647226        0.86026  0.749132   

   nDCG@100  AP(rel=2)          mrt  
0  0.680396   0.537098  1977.767955  


Unnamed: 0,name,RR(rel=2),R(rel=2)@100,R(rel=2)@1000,nDCG@10,nDCG@100,AP(rel=2),mrt
0,AAAT 1.000000 fb_doc,0.86056,0.641845,0.863299,0.724838,0.668962,0.51078,1747.003051
0,AAAT 2.000000 fb_doc,0.833655,0.644076,0.84943,0.732905,0.674183,0.526402,1773.806532
0,AAAT 3.000000 fb_doc,0.867128,0.650331,0.859627,0.744428,0.682528,0.533567,1819.796863
0,AAAT 4.000000 fb_doc,0.873547,0.647226,0.86026,0.749132,0.680396,0.537098,1977.767955
