In [None]:
# This notebook contains:
# - testing of the IDCM model on the TREC45 dataset
# - ColBERT, BERT-CAT and BERT-DOT models, used to compare the efficiency of IDCM against these models

In [2]:
## IDCM Model for Testing

from typing import Dict, Union
import torch
from torch import nn as nn
from transformers import AutoTokenizer, AutoModel
from transformers import PreTrainedModel,PretrainedConfig

# Hugging Face Hub id of the IDCM checkpoint loaded at the end of this cell.
# NOTE: the BERT-DOT cell later reassigns this global to the BERT-DOT checkpoint name.
pre_trained_model_name = "sebastian-hofstaetter/idcm-distilbert-msmarco_doc"

class IDCM_Config(PretrainedConfig):
    '''
    Configuration of IDCM_InferenceOnly (each field is read as cfg.<name> in its __init__).
    '''
    # Must equal the model_type stored in the checkpoint config. Left empty, loading the IDCM
    # checkpoint logs "You are using a model of type IDCM to instantiate a model of type ."
    # (BERT_Cat_Config and ColBERTConfig set model_type the same way).
    model_type = "IDCM"

    # encoder name/path; IDCM_InferenceOnly also accepts an already constructed model object here
    bert_model:str
    # how many passages get scored by BERT
    sample_n:int

    # type of fast module ("ck" or "ck-small")
    sample_context:str

    # how many passages to take from bert to create the final score (usually the same as sample_n, but could be set to 1 for max-p)
    top_k_chunks:int

    # window size
    chunk_size:int

    # left and right overlap (added to each window)
    overlap:int

    # token id treated as padding when chunking documents
    padding_idx:int = 0

class IDCM_InferenceOnly(PreTrainedModel):
    '''
    IDCM is a neural re-ranking model for long documents, it creates an intra-document cascade between a fast (CK) and a slow module (BERT_Cat)
    This code is only usable for inference (we removed the training mechanism for simplicity)

    Flow of forward() as implemented below:
      1. drop the first document token, pad, and cut the document into overlapping windows
         ("chunks") of chunk_size + 2 * overlap tokens, advancing chunk_size tokens per window;
      2. if sample_n > -1, score every non-empty chunk with the cheap kernel-based module and
         keep only the sample_n best chunks per document;
      3. score the remaining (query, chunk) pairs with the BERT encoder + a linear layer;
      4. combine the best top_k_chunks BERT chunk scores per document with the learned
         weights in top_k_scoring.
    '''

    config_class = IDCM_Config
    base_model_prefix = "bert_model"

    def __init__(self,
                 cfg) -> None:
        '''
        Build the model from an IDCM_Config.

        cfg.bert_model may be a model name/path (loaded with AutoModel.from_pretrained) or an
        already constructed model object (used as is).
        '''
        super().__init__(cfg)

        #
        # bert - scoring
        #
        if isinstance(cfg.bert_model, str):
            self.bert_model = AutoModel.from_pretrained(cfg.bert_model)
        else:
            self.bert_model = cfg.bert_model

        #
        # final scoring (combination of bert scores)
        #
        self._classification_layer = torch.nn.Linear(self.bert_model.config.hidden_size, 1)
        self.top_k_chunks = cfg.top_k_chunks
        # one learned weight per aggregated chunk score, initialised to 1
        self.top_k_scoring = nn.Parameter(torch.full([1,self.top_k_chunks], 1, dtype=torch.float32, requires_grad=True))

        #
        # local self attention
        #
        self.padding_idx= cfg.padding_idx
        self.chunk_size = cfg.chunk_size
        self.overlap = cfg.overlap
        # window length: the chunk plus `overlap` context tokens on each side
        self.extended_chunk_size = self.chunk_size + 2 * self.overlap

        #
        # sampling stuff
        #
        self.sample_n = cfg.sample_n
        self.sample_context = cfg.sample_context

        # Fast module: a width-3 convolution over the (detached) encoder embeddings.
        # `config.dim` is the DistilBERT name for the hidden size; NOTE(review): other encoder
        # configs may not define it.
        if self.sample_context == "ck":
            i = 3
            self.sample_cnn3 = nn.Sequential(
                        nn.ConstantPad1d((0,i - 1), 0),
                        nn.Conv1d(kernel_size=i, in_channels=self.bert_model.config.dim, out_channels=self.bert_model.config.dim),
                        nn.ReLU()
                        ) 
        elif self.sample_context == "ck-small":
            i = 3
            # project embeddings to 384 dims, then a 384 -> 128 channel convolution
            self.sample_projector = nn.Linear(self.bert_model.config.dim,384)
            self.sample_cnn3 = nn.Sequential(
                        nn.ConstantPad1d((0,i - 1), 0),
                        nn.Conv1d(kernel_size=i, in_channels=384, out_channels=128),
                        nn.ReLU()
                        ) 
        # NOTE(review): any other sample_context leaves sample_cnn3 undefined, and the fallback
        # branch in forward() uses self.tk_projector / self.tk_contextualizer, which are never
        # created in this class -> only "ck" and "ck-small" are usable.

        # 11 kernel activations -> one sampling score per chunk
        self.sampling_binweights = nn.Linear(11, 1, bias=True)
        torch.nn.init.uniform_(self.sampling_binweights.weight, -0.01, 0.01)
        self.kernel_alpha_scaler = nn.Parameter(torch.full([1,1,11], 1, dtype=torch.float32, requires_grad=True))

        # Gaussian kernel centres (1.0, 0.9, 0.7, ..., -0.9) and a fixed width of 0.1, registered
        # as buffers of shape (1, 1, 1, 11) so they broadcast against the cosine matrix in forward()
        self.register_buffer("mu",nn.Parameter(torch.tensor([1.0, 0.9, 0.7, 0.5, 0.3, 0.1, -0.1, -0.3, -0.5, -0.7, -0.9]), requires_grad=False).view(1, 1, 1, -1))
        self.register_buffer("sigma", nn.Parameter(torch.tensor([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]), requires_grad=False).view(1, 1, 1, -1))
        

    def forward(self,
                query: Dict[str, torch.LongTensor],
                document: Dict[str, torch.LongTensor],
                use_fp16:bool = True,
                output_secondary_output: bool = False):
        '''
        Score a batch of (query, document) pairs.

        Args:
            query: tokenizer output with 2-D "input_ids" and "attention_mask" (batch, query_len).
            document: tokenizer output with 2-D "input_ids" (batch, doc_len); its first token is
                dropped (presumably [CLS] -- TODO confirm).
            use_fp16: accepted but not referenced in this inference-only implementation.
            output_secondary_output: additionally return intermediate scores / chunk masks.

        Returns:
            `score` is a (batch,) tensor.
            sample_n == -1: (score, info_dict) if output_secondary_output else (score, per-chunk BERT scores)
            otherwise:      (score, per-chunk BERT scores, info_dict) if output_secondary_output else score
        '''

        #
        # patch up documents - local self attention
        #
        # drop the first document token (presumably [CLS])
        document_ids = document["input_ids"][:,1:]
        # right padding amount; the left side gets `overlap` padding tokens below
        if document_ids.shape[1] > self.overlap:
            needed_padding = self.extended_chunk_size - (((document_ids.shape[1]) % self.chunk_size)  - self.overlap)
        else:
            needed_padding = self.extended_chunk_size - self.overlap - document_ids.shape[1]
        orig_doc_len = document_ids.shape[1]  # not referenced afterwards

        document_ids = nn.functional.pad(document_ids,(self.overlap, needed_padding),value=self.padding_idx)
        # windows of extended_chunk_size tokens every chunk_size tokens -> (batch, chunks, extended_chunk_size)
        chunked_ids = document_ids.unfold(1,self.extended_chunk_size,self.chunk_size)

        batch_size = chunked_ids.shape[0]
        chunk_pieces = chunked_ids.shape[1]


        # flatten to one row per chunk; keep only chunks whose central part (overlap removed) has
        # at least one non-padding token.
        # NOTE(review): with overlap == 0 the slice [0:-0] is empty, so every chunk is dropped.
        chunked_ids_unrolled=chunked_ids.reshape(-1,self.extended_chunk_size)
        packed_indices = (chunked_ids_unrolled[:,self.overlap:-self.overlap] != self.padding_idx).any(-1)
        orig_packed_indices = packed_indices.clone()
        ids_packed = chunked_ids_unrolled[packed_indices]
        mask_packed = (ids_packed != self.padding_idx)

        total_chunks=chunked_ids_unrolled.shape[0]

        # repeat each query once per chunk of its document, then keep the rows of kept chunks
        packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["input_ids"].shape[1])[packed_indices]
        packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["attention_mask"].shape[1])[packed_indices]

        #
        # sampling
        # 
        if self.sample_n > -1:
            
            #
            # ck learned matches
            #
            # L2-normalised, convolution-contextualised vectors of the (detached) encoder embeddings
            if self.sample_context == "ck-small":
                query_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(packed_query_ids).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
                document_ctx = torch.nn.functional.normalize(self.sample_cnn3(self.sample_projector(self.bert_model.embeddings(ids_packed).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
            elif self.sample_context == "ck":
                query_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(packed_query_ids).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
                document_ctx = torch.nn.functional.normalize(self.sample_cnn3((self.bert_model.embeddings(ids_packed).detach()).transpose(1,2)).transpose(1, 2),p=2,dim=-1)
            else:
                # NOTE(review): tk_projector / tk_contextualizer are not defined by __init__, so this
                # branch raises AttributeError in this inference-only version.
                qe = self.tk_projector(self.bert_model.embeddings(packed_query_ids).detach())
                de = self.tk_projector(self.bert_model.embeddings(ids_packed).detach())
                query_ctx = self.tk_contextualizer(qe.transpose(1,0),src_key_padding_mask=~packed_query_mask.bool()).transpose(1,0)
                document_ctx = self.tk_contextualizer(de.transpose(1,0),src_key_padding_mask=~mask_packed.bool()).transpose(1,0)
        
                query_ctx =   torch.nn.functional.normalize(query_ctx,p=2,dim=-1)
                document_ctx= torch.nn.functional.normalize(document_ctx,p=2,dim=-1)

            # query-term x chunk-term cosine similarities with a trailing kernel axis
            cosine_matrix = torch.bmm(query_ctx,document_ctx.transpose(-1, -2)).unsqueeze(-1)

            # Gaussian kernel activations (document padding masked out), log-summed per query term
            # (query padding masked out), summed over query terms, then combined into one score per chunk
            kernel_activations = torch.exp(- torch.pow(cosine_matrix - self.mu, 2) / (2 * torch.pow(self.sigma, 2))) * mask_packed.unsqueeze(-1).unsqueeze(1)
            kernel_res = torch.log(torch.clamp(torch.sum(kernel_activations, 2) * self.kernel_alpha_scaler, min=1e-4)) * packed_query_mask.unsqueeze(-1)
            packed_patch_scores = self.sampling_binweights(torch.sum(kernel_res, 1))

            
            # scatter chunk scores back to (batch, chunks); chunks that were not scored stay 0 and are
            # then set to -9000 so they sort last. NOTE(review): a genuine score of exactly 0.0 is
            # treated the same way.
            sampling_scores_per_doc = torch.zeros((total_chunks,1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
            sampling_scores_per_doc[packed_indices] = packed_patch_scores
            sampling_scores_per_doc = sampling_scores_per_doc.reshape(batch_size,-1,)
            sampling_scores_per_doc_orig = sampling_scores_per_doc.clone()
            sampling_scores_per_doc[sampling_scores_per_doc == 0] = -9000

            # per-document ranking of chunks, shifted by a per-row offset to index the flat chunk list
            sampling_sorted = sampling_scores_per_doc.sort(descending=True)
            sampled_indices = sampling_sorted.indices + torch.arange(0,sampling_scores_per_doc.shape[0]*sampling_scores_per_doc.shape[1],sampling_scores_per_doc.shape[1],device=sampling_scores_per_doc.device).unsqueeze(-1)

            # keep the best sample_n chunks per document as a mask over all flat chunk positions
            sampled_indices = sampled_indices[:,:self.sample_n]
            sampled_indices_mask = torch.zeros_like(packed_indices).scatter(0, sampled_indices.reshape(-1), 1)

            # pack indices
            # intersect with the non-empty chunks (a document with fewer than sample_n non-empty
            # chunks also has empty ones among its top sample_n)
            packed_indices = sampled_indices_mask * packed_indices
    
            packed_query_ids = query["input_ids"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["input_ids"].shape[1])[packed_indices]
            packed_query_mask = query["attention_mask"].unsqueeze(1).expand(-1,chunk_pieces,-1).reshape(-1,query["attention_mask"].shape[1])[packed_indices]

            ids_packed = chunked_ids_unrolled[packed_indices]
            mask_packed = (ids_packed != self.padding_idx)

        #
        # expensive bert scores
        #
        # encode [query ; chunk] pairs and map each pooled vector to one score
        bert_vecs = self.forward_representation(torch.cat([packed_query_ids,ids_packed],dim=1),torch.cat([packed_query_mask,mask_packed],dim=1))
        packed_patch_scores = self._classification_layer(bert_vecs) 

        # scatter back to (batch, chunks); unscored chunks stay 0
        scores_per_doc = torch.zeros((total_chunks,1), dtype=packed_patch_scores.dtype, layout=packed_patch_scores.layout, device=packed_patch_scores.device)
        scores_per_doc[packed_indices] = packed_patch_scores
        scores_per_doc = scores_per_doc.reshape(batch_size,-1,)
        scores_per_doc_orig = scores_per_doc.clone()
        scores_per_doc_orig_sorter = scores_per_doc.clone()  # only modified below, never read

        if self.sample_n > -1:
            scores_per_doc = scores_per_doc * sampled_indices_mask.view(batch_size,-1)
        
        #
        # aggregate bert scores
        #

        # make sure there are at least top_k_chunks columns to take below
        if scores_per_doc.shape[1] < self.top_k_chunks:
            scores_per_doc = nn.functional.pad(scores_per_doc,(0, self.top_k_chunks - scores_per_doc.shape[1]))

        # sort with a -9000 sentinel for missing chunks, then turn the sentinel back into 0
        scores_per_doc[scores_per_doc == 0] = -9000
        scores_per_doc_orig_sorter[scores_per_doc_orig_sorter == 0] = -9000
        score = torch.sort(scores_per_doc,descending=True,dim=-1).values
        score[score <= -8900] = 0

        # learned weighted sum of the best top_k_chunks chunk scores -> (batch,)
        score = (score[:,:self.top_k_chunks] * self.top_k_scoring).sum(dim=1)

        if self.sample_n == -1:
            if output_secondary_output:
                return score,{
                    "packed_indices": orig_packed_indices.view(batch_size,-1),
                    "bert_scores":scores_per_doc_orig
                }
            else:
                return score,scores_per_doc_orig    
        else:
            if output_secondary_output:
                return score,scores_per_doc_orig,{
                    "score": score,
                    "packed_indices": orig_packed_indices.view(batch_size,-1),
                    "sampling_scores":sampling_scores_per_doc_orig,
                    "bert_scores":scores_per_doc_orig
                }

            return score

    def forward_representation(self, ids,mask,type_ids=None) -> torch.Tensor:
        '''
        Encode concatenated (query, chunk) id rows and return one vector per row.

        For DistilBERT this is the first-position vector of the last hidden state; the other
        branches take the second element unpacked from the encoder output.
        NOTE(review): `_, pooled = ...` assumes the encoder returns a plain 2-tuple; recent
        transformers versions return a ModelOutput object by default -- confirm before using a
        non-DistilBERT encoder.
        '''
        
        if self.bert_model.base_model_prefix == 'distilbert': # diff input / output 
            pooled = self.bert_model(input_ids=ids,
                                     attention_mask=mask)[0][:,0,:]
        elif self.bert_model.base_model_prefix == 'longformer':
            # NOTE(review): ((1-ids)*mask) is non-zero for every unmasked id other than 1 --
            # confirm this is the intended global-attention mask
            _, pooled = self.bert_model(input_ids=ids,
                                        attention_mask=mask.long(),
                                        global_attention_mask = ((1-ids)*mask).long())
        elif self.bert_model.base_model_prefix == 'roberta': # no token type ids
            _, pooled = self.bert_model(input_ids=ids,
                                        attention_mask=mask)
        else:
            _, pooled = self.bert_model(input_ids=ids,
                                        token_type_ids=type_ids,
                                        attention_mask=mask)

        return pooled


## IDCM Tokenizer and Model
# The encoder inside the IDCM checkpoint is initialised from distilbert-base-uncased (see this
# cell's load log), so queries and documents are tokenised with the matching tokenizer.
idcm_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="distilbert-base-uncased")
idcm_model = IDCM_InferenceOnly.from_pretrained(pretrained_model_name_or_path=pre_trained_model_name)

You are using a model of type IDCM to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
## BERT-CAT Model

from transformers import AutoTokenizer,AutoModel, PreTrainedModel,PretrainedConfig
from typing import Dict
import torch

class BERT_Cat_Config(PretrainedConfig):
    """Configuration of BERT_Cat (each field is read as cfg.<name> in BERT_Cat.__init__)."""
    model_type = "BERT_Cat"
    bert_model: str  # name/path passed to AutoModel.from_pretrained for the encoder
    trainable: bool = True  # value assigned to requires_grad of every encoder parameter

class BERT_Cat(PreTrainedModel):
    """
    Vanilla / mono BERT over a concatenated input ("BERT_Cat").

    Query and document must be concatenated into a single tokenizer output before calling the
    model (which makes batching possible); the encoder's first output vector is mapped to one
    relevance score by a linear layer.
    """
    config_class = BERT_Cat_Config
    base_model_prefix = "bert_model"

    def __init__(self,
                 cfg) -> None:
        super().__init__(cfg)

        encoder = AutoModel.from_pretrained(cfg.bert_model)
        # freeze or unfreeze the whole encoder according to the config
        for parameter in encoder.parameters():
            parameter.requires_grad = cfg.trainable
        self.bert_model = encoder

        self._classification_layer = torch.nn.Linear(encoder.config.hidden_size, 1)

    def forward(self,
                query_n_doc_sequence):
        # element 0 of the encoder output is the last hidden state (assuming a DistilBERT-style
        # encoder); position 0 of the sequence is the first token
        last_hidden = self.bert_model(**query_n_doc_sequence)[0]
        first_token_vecs = last_hidden[:, 0, :]
        return self._classification_layer(first_token_vecs)

# Load the BERT-CAT checkpoint; it is paired with the distilbert-base-uncased tokenizer
# (the same tokenizer used for IDCM and ColBERT in this notebook).
bertcat_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="distilbert-base-uncased")
bertcat_model = BERT_Cat.from_pretrained(pretrained_model_name_or_path="sebastian-hofstaetter/distilbert-cat-margin_mse-T2-msmarco")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
## COLBERT Model

from transformers import AutoTokenizer,AutoModel, PreTrainedModel,PretrainedConfig
from typing import Dict
import torch

class ColBERTConfig(PretrainedConfig):
    """Configuration of ColBERT (fields are read as cfg.<name> in ColBERT.__init__)."""
    model_type = "ColBERT"
    bert_model: str  # name/path passed to AutoModel.from_pretrained for the encoder
    compression_dim: int = 768  # output size of the linear projection applied to every token vector
    dropout: float = 0.0  # not referenced by the ColBERT class in this notebook
    return_vecs: bool = False  # not referenced by the ColBERT class in this notebook
    trainable: bool = True  # value assigned to requires_grad of every encoder parameter

class ColBERT(PreTrainedModel):
    """
    ColBERT model from: https://arxiv.org/pdf/2004.12832.pdf
    We use a dot-product instead of cosine per term (slightly better)

    Query and document are encoded independently by the same encoder. Every token vector is
    projected to cfg.compression_dim dims. The score is the sum over (non-padding) query tokens
    of the maximum dot-product with any non-padding document token.
    """
    config_class = ColBERTConfig
    base_model_prefix = "bert_model"

    def __init__(self,
                 cfg) -> None:
        super().__init__(cfg)
        
        self.bert_model = AutoModel.from_pretrained(cfg.bert_model)

        for p in self.bert_model.parameters():
            p.requires_grad = cfg.trainable

        self.compressor = torch.nn.Linear(self.bert_model.config.hidden_size, cfg.compression_dim)

    def forward(self,
                query: Dict[str, torch.LongTensor],
                document: Dict[str, torch.LongTensor]):
        """
        Score query/document pairs.

        Args:
            query, document: tokenizer outputs (keyword arguments for the encoder), each with an
                "attention_mask" entry.

        Returns:
            one score per batch row.
        """

        query_vecs = self.forward_representation(query)
        document_vecs = self.forward_representation(document)

        score = self.forward_aggregation(query_vecs,document_vecs,query["attention_mask"],document["attention_mask"])
        return score

    def forward_representation(self,
                               tokens,
                               sequence_type=None) -> torch.Tensor:
        """
        Encode one tokenizer output into compressed per-token vectors.

        Args:
            tokens: tokenizer output, unpacked as keyword arguments into the encoder.
            sequence_type: "doc_encode" / "query_encode" zero out the vectors of padding
                positions (useful when storing vectors); any other value leaves them untouched.
        """
        
        vecs = self.bert_model(**tokens)[0] # assuming a distilbert model here
        vecs = self.compressor(vecs)

        # if encoding only, zero-out the mask values so we can compress storage
        if sequence_type == "doc_encode" or sequence_type == "query_encode":
            # `tokens` is unpacked straight into the encoder above, so it is a flat tokenizer
            # output whose padding mask is "attention_mask" (a nested tokens["tokens"]["mask"]
            # entry could not exist without breaking the encoder call).
            vecs = vecs * tokens["attention_mask"].unsqueeze(-1)

        return vecs

    def forward_aggregation(self,query_vecs, document_vecs,query_mask,document_mask):
        """
        MaxSim aggregation: for every query token take the best-matching document token, then sum
        over the non-padding query tokens.
        """
        
        # create initial term-x-term scores (dot-product)
        score = torch.bmm(query_vecs, document_vecs.transpose(2,1))

        # mask out padding on the doc dimension (mask by -10000, because max should not select those, setting it to 0 might select them)
        exp_mask = document_mask.bool().unsqueeze(1).expand(-1,score.shape[1],-1)
        score[~exp_mask] = - 10000

        # max pooling over document dimension
        score = score.max(-1).values

        # mask out padding query values
        score[~(query_mask.bool())] = 0

        # sum over query values
        score = score.sum(-1)

        return score

# Load the ColBERT checkpoint; like IDCM and BERT-CAT it is used with the
# distilbert-base-uncased tokenizer.
colbert_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="distilbert-base-uncased")
colbert_model = ColBERT.from_pretrained(pretrained_model_name_or_path="sebastian-hofstaetter/colbert-distilbert-margin_mse-T2-msmarco")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
## BERT-DOT Model

from transformers import AutoTokenizer, AutoModel

# Swapping in the original "distilbert-base-uncased" checkpoint breaks the usage example:
# the score ordering comes out reversed.
#pre_trained_model_name = "distilbert-base-uncased"
pre_trained_model_name = "sebastian-hofstaetter/distilbert-dot-margin_mse-T2-msmarco"

# tokenizer and encoder are both loaded from the same checkpoint name
bertdot_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pre_trained_model_name)
bertdot_model = AutoModel.from_pretrained(pretrained_model_name_or_path=pre_trained_model_name)

In [7]:
import numpy as np
import time
import pandas as pd

In [8]:

def unrolled_to_ranked_result(unrolled_results):
    '''
    Turn per-query (doc_id, score) lists into per-query rankings of doc ids.

    Args:
        unrolled_results: dict query_id -> iterable of (doc_id, score) pairs.

    Returns:
        dict query_id -> list of doc ids ordered by descending score (ties keep their input
        order, because the sort is stable).
    '''
    def by_score(pair):
        return pair[1]

    return {
        query_id: [doc_id for doc_id, _score in sorted(pairs, key=by_score, reverse=True)]
        for query_id, pairs in unrolled_results.items()
    }

In [9]:
def calculate_metrics_plain(ranking, qrels,binarization_point=1.0,return_per_query=False,metric_config=None):
    '''
    Calculate the main evaluation metrics for the given rankings (without looking at candidates).

    Args:
        ranking: dict query_id -> list of doc ids, best first (e.g. from unrolled_to_ranked_result).
        qrels: dict query_id -> {doc_id: grade} (e.g. from load_qrels). Query ids are matched with
            `in`, so they must have the same type as the keys of `ranking` (str vs int matters).
        binarization_point: minimum grade counted as relevant for MRR, Recall and MAP; nDCG uses
            the grades of every document present in qrels.
        return_per_query: also return the per-query metric arrays.
        metric_config: dict with "MRR+Recall@" (list of cutoffs), "nDCG@" (list of cutoffs) and
            "MAP@" (single cutoff). Defaults to the notebook-level `global_metric_config`,
            looked up at call time.

    Returns:
        dict metric name -> value; with return_per_query=True the tuple
        (dict, rr_per_query, ap_per_query, recall_per_query, ndcg_per_query).
        Averages are over the queries found in qrels; if there are none they are nan and a
        warning is emitted.
    '''
    import warnings

    if metric_config is None:
        # resolved at call time, so the config cell may be run after this definition
        metric_config = global_metric_config
    mrr_recall_cutoffs = metric_config["MRR+Recall@"]
    ndcg_cutoffs = metric_config["nDCG@"]
    map_cutoff = metric_config["MAP@"]

    ranked_queries = len(ranking)
    ap_per_candidate_depth = np.zeros((ranked_queries))
    rr_per_candidate_depth = np.zeros((len(mrr_recall_cutoffs),ranked_queries))
    rank_per_candidate_depth = np.zeros((len(mrr_recall_cutoffs),ranked_queries))
    recall_per_candidate_depth = np.zeros((len(mrr_recall_cutoffs),ranked_queries))
    ndcg_per_candidate_depth = np.zeros((len(ndcg_cutoffs),ranked_queries))
    evaluated_queries = 0

    for query_index,(query_id,ranked_doc_ids) in enumerate(ranking.items()):

        # unjudged queries keep 0 in every per-query array and are not counted in the averages
        if query_id not in qrels:
            continue
        evaluated_queries += 1

        relevant_ids = np.array(list(qrels[query_id].keys())) # key, value guaranteed in same order
        relevant_grades = np.array(list(qrels[query_id].values()))
        sorted_relevant_grades = np.sort(relevant_grades)[::-1]

        num_relevant = relevant_ids.shape[0]
        np_rank = np.array(ranked_doc_ids)
        # np.isin is the supported replacement for np.in1d (deprecated in NumPy 2.0)
        relevant_mask = np.isin(np_rank,relevant_ids) # shape: (ranking_depth,) - type: bool

        binary_relevant = relevant_ids[relevant_grades >= binarization_point]
        binary_num_relevant = binary_relevant.shape[0]
        binary_relevant_mask = np.isin(np_rank,binary_relevant) # shape: (ranking_depth,) - type: bool

        # check if we have a relevant document at all in the results -> if not skip and leave 0
        if np.any(binary_relevant_mask):

            # 1-based ranks of the (binary) relevant documents
            ranks = np.arange(1,binary_relevant_mask.shape[0]+1)[binary_relevant_mask]

            # ap: only ranks within the MAP cutoff contribute
            map_ranks = ranks[ranks <= map_cutoff]
            ap = np.arange(1,map_ranks.shape[0]+1) / map_ranks
            ap = np.sum(ap) / binary_num_relevant
            ap_per_candidate_depth[query_index] = ap

            # mrr only the first relevant rank is used
            first_rank = ranks[0]

            for cut_indx, cutoff in enumerate(mrr_recall_cutoffs):

                curr_ranks = ranks.copy()
                curr_ranks[curr_ranks > cutoff] = 0

                recall = (curr_ranks > 0).sum(axis=0) / binary_num_relevant
                recall_per_candidate_depth[cut_indx,query_index] = recall

                # mrr: ignore ranks that are out of the interest area (leave 0)
                if first_rank <= cutoff:
                    rr_per_candidate_depth[cut_indx,query_index] = 1 / first_rank
                    rank_per_candidate_depth[cut_indx,query_index] = first_rank

        if np.any(relevant_mask):

            # 1-based ranks of every judged document
            ranks = np.arange(1,relevant_mask.shape[0]+1)[relevant_mask]

            # grade of each judged document, in ranking order
            grades_per_rank = np.ndarray(ranks.shape[0],dtype=int)
            for i,doc_id in enumerate(np_rank[relevant_mask]):
                # take the scalar index explicitly: assigning a 1-element array to an int slot is
                # deprecated since NumPy 1.25
                grades_per_rank[i]=np.flatnonzero(relevant_ids==doc_id)[0]

            grades_per_rank = relevant_grades[grades_per_rank]

            # ndcg = dcg / idcg
            for cut_indx, cutoff in enumerate(ndcg_cutoffs):
                # idcg from the best possible ordering of the judged grades
                idcg = (sorted_relevant_grades[:cutoff] / np.log2(1 + np.arange(1,min(num_relevant,cutoff) + 1)))

                curr_ranks = ranks.copy()
                curr_ranks[curr_ranks > cutoff] = 0

                # ranks set to 0 give log2(1) = 0, i.e. inf / nan terms, which are zeroed here
                with np.errstate(divide='ignore', invalid='ignore'):
                    c = np.true_divide(grades_per_rank,np.log2(1 + curr_ranks))
                    c[c == np.inf] = 0
                    dcg = np.nan_to_num(c)

                nDCG = dcg.sum(axis=-1) / idcg.sum()

                ndcg_per_candidate_depth[cut_indx,query_index] = nDCG

    if evaluated_queries == 0:
        warnings.warn("calculate_metrics_plain: no query of the ranking was found in qrels; "
                      "check that query ids have the same type (e.g. str vs int) in both")

    mrr = rr_per_candidate_depth.sum(axis=-1) / evaluated_queries
    relevant = (rr_per_candidate_depth > 0).sum(axis=-1)
    non_relevant = (rr_per_candidate_depth == 0).sum(axis=-1)

    # mean / median first-relevant rank over queries that have one within the cutoff (0 if none)
    avg_rank=np.apply_along_axis(lambda v: np.mean(v[np.nonzero(v)]), -1, rank_per_candidate_depth)
    avg_rank[np.isnan(avg_rank)]=0.

    median_rank=np.apply_along_axis(lambda v: np.median(v[np.nonzero(v)]), -1, rank_per_candidate_depth)
    median_rank[np.isnan(median_rank)]=0.

    map_score = ap_per_candidate_depth.sum(axis=-1) / evaluated_queries
    recall = recall_per_candidate_depth.sum(axis=-1) / evaluated_queries
    nDCG = ndcg_per_candidate_depth.sum(axis=-1) / evaluated_queries

    local_dict={}

    for cut_indx, cutoff in enumerate(mrr_recall_cutoffs):

        local_dict['MRR@'+str(cutoff)] = mrr[cut_indx]
        local_dict['Recall@'+str(cutoff)] = recall[cut_indx]
        local_dict['QueriesWithNoRelevant@'+str(cutoff)] = non_relevant[cut_indx]
        local_dict['QueriesWithRelevant@'+str(cutoff)] = relevant[cut_indx]
        local_dict['AverageRankGoldLabel@'+str(cutoff)] = avg_rank[cut_indx]
        local_dict['MedianRankGoldLabel@'+str(cutoff)] = median_rank[cut_indx]

    for cut_indx, cutoff in enumerate(ndcg_cutoffs):
        local_dict['nDCG@'+str(cutoff)] = nDCG[cut_indx]

    local_dict['QueriesRanked'] = evaluated_queries
    local_dict['MAP@'+str(map_cutoff)] = map_score

    if return_per_query:
        return local_dict,rr_per_candidate_depth,ap_per_candidate_depth,recall_per_candidate_depth,ndcg_per_candidate_depth
    else:
        return local_dict
        

In [10]:
def load_qrels(path):
    '''
    Load a TREC qrels file: whitespace-separated "qid <ignored> docid grade" lines.

    Judgements with a grade of 0.0001 or less are dropped; blank lines are skipped.

    Returns:
        dict qid (str) -> {docid (str): grade (float)}

    Raises:
        IOError: if a non-blank line has fewer than 4 fields or a non-numeric grade.
    '''
    qids_to_relevant_passageids = {}
    with open(path,'r') as f:
        for line in f:
            fields = line.strip().split()
            if not fields:
                continue  # tolerate blank lines, e.g. at the end of the file
            try:
                qid, doc_id, grade = fields[0], fields[2], float(fields[3])
            except (IndexError, ValueError) as err:
                raise IOError('\"%s\" is not valid format' % fields) from err
            if grade > 0.0001:
                qids_to_relevant_passageids.setdefault(qid, {})[doc_id] = grade
    return qids_to_relevant_passageids
        

In [11]:
# cutoffs read by calculate_metrics_plain
global_metric_config = {}
global_metric_config["MRR+Recall@"] = [10, 20, 100, 200, 1000]  # multiple allowed
global_metric_config["nDCG@"] = [3, 5, 10, 20, 1000]  # multiple allowed
global_metric_config["MAP@"] = 1000  # only one allowed

In [12]:
# evaluation settings: relevance threshold for binary metrics and the qrels file to score against
test_config = {
    "binarization_point": 1,
    "qrels": "data/qrels.trec6-8.nocr",
}

In [14]:
## Load TREC45 documents into a dataframe
# one row per document; the DOCNO, TEXT, HEADLINE and GRAPHIC columns are used in the next cell
docs = pd.read_csv(filepath_or_buffer="data/trec45-documents.csv")

  exec(code_obj, self.user_global_ns, self.user_ns)


In [15]:
# Keep handles on the id and text columns, then drop columns not needed for ranking.
docnos = docs["DOCNO"]
doctext = docs["TEXT"]
# errors="ignore" keeps the cell re-runnable: `docs` is reassigned, so a second run would
# otherwise fail with KeyError because the columns are already gone.
docs = docs.drop(columns=['HEADLINE','GRAPHIC'], errors="ignore")

In [16]:
## TREC45 queries for testing (300-450)
# One query per line: "<numeric id> <query text>".
with open("data/queries.txt","r") as query_file_handle:  # closed automatically
    query_file = query_file_handle.readlines()
queries ={}
for query in query_file:
    # split on the first whitespace run, so tabs / repeated spaces after the id are handled
    lst = query.split(None, 1)
    if len(lst) < 2:
        continue  # skip blank or id-only lines instead of crashing with IndexError
    # rstrip("\r\n") also removes the "\r" of Windows line endings
    queries[int(lst[0])] = lst[1].rstrip("\r\n")


In [17]:
## BM25 run for all queries (top-1000 per query; Indri with Krovetz stemming, per the file name)
# -- truncated to the top 100 in the next cell.
# Space-separated TREC run lines; NOTE(review): in the usual TREC run layout the column named
# 'bm25metric' holds the score and 'bm25' the run tag -- confirm against the file.
bm25 = pd.read_csv("data/trec45_indri_kstem_top1000_bm25.out", sep=" ", names=['query','Q0','DOCNO','rank','bm25metric','bm25'])

In [18]:
# Keep the top 100 candidates per query. TREC run ranks are 1-based (as written by Indri /
# Anserini), so "<= 100" is needed -- "< 100" kept only 99 documents.
bm25 = bm25[bm25["rank"] <= 100]
# errors="ignore" keeps the cell re-runnable after the columns have been dropped once
bm25 = bm25.drop(columns=['Q0','bm25metric','bm25'], errors="ignore")


In [19]:
# flag candidates whose document exists in the loaded collection, and keep only those
bm25["present"] = np.where(bm25["DOCNO"].isin(docs["DOCNO"]), 1, 0)
to_rank = bm25.loc[bm25["present"] == 1]

In [20]:
# distinct query ids that still have at least one candidate document
unique_queries = set(to_rank["query"].tolist())

In [21]:
# attach the document columns (TEXT, ...) to every candidate row
to_rank_merged = pd.merge(to_rank, docs, how="left", on="DOCNO")

In [29]:
# query id -> candidate document texts, in the row order of to_rank_merged
to_rank_docs = {
    query_id: to_rank_merged.loc[to_rank_merged["query"] == query_id, "TEXT"].tolist()
    for query_id in unique_queries
}


In [30]:
# query id -> query string, for the queries that have candidates and a known text
query_text = {query_id: queries[query_id] for query_id in unique_queries if query_id in queries}


In [24]:

def make_bert_input(tokenizer,doctext):
    '''
    Tokenize each document text on its own.

    Args:
        tokenizer: a Hugging Face tokenizer (or any callable with the same keyword interface).
        doctext: iterable of document strings.

    Returns:
        list with one tokenizer output per document (PyTorch tensors, batch of 1), each cut to
        at most 500 tokens.
    '''
    # truncation=True makes the 500-token cut explicit; with only max_length set, the tokenizer
    # warns that truncation was not explicitly activated
    return [tokenizer(text, return_tensors="pt", max_length=500, truncation=True) for text in doctext]


In [38]:
def make_main_dict(query_text,model,bert_input,tokenizer,model_name,doc_ids=None):
    '''
    Score every pre-tokenized candidate document against one query.

    Args:
        query_text: raw query string.
        model: the scoring model (IDCM_InferenceOnly for "IDCM", ColBERT for "COLBERT").
        bert_input: list of tokenizer outputs, one per candidate document (see make_bert_input).
        tokenizer: tokenizer used to encode the query.
        model_name: "IDCM" or "COLBERT".
        doc_ids: optional document ids aligned position-by-position with `bert_input`. If
            omitted, result i is labelled with the notebook-level `docnos[i]` (the previous
            behaviour), which is only correct when `bert_input` follows the order of `docnos`.

    Returns:
        list of (doc_id, score) tuples in the order of `bert_input`.

    Raises:
        ValueError: for an unsupported model_name (previously an empty list was returned silently).
    '''
    if model_name not in ("IDCM", "COLBERT"):
        raise ValueError("make_main_dict: unsupported model_name %r (expected 'IDCM' or 'COLBERT')" % (model_name,))
    if doc_ids is None:
        doc_ids = docnos

    main_dict=[]
    # inference only: skip building the autograd graph (saves memory and time, same scores)
    with torch.no_grad():
        if model_name == "IDCM":
            query_input = tokenizer(query_text ,return_tensors="pt",max_length=30,truncation=True)
            for i, inp in enumerate(bert_input):
                score = model(query_input, inp).squeeze(0)
                main_dict.append((doc_ids[i],score.item()))
        else:  # "COLBERT"
            query_input = tokenizer(query_text)
            # append 8 [MASK] tokens (id 103) to the query, all of them attended to
            input_ids = list(query_input["input_ids"]) + [103] * 8
            attention_mask = list(query_input["attention_mask"]) + [1] * 8
            query_input["input_ids"] = torch.LongTensor(input_ids).unsqueeze(0)
            query_input["attention_mask"] = torch.LongTensor(attention_mask).unsqueeze(0)
            for i, inp in enumerate(bert_input):
                score = model(query_input, inp).squeeze(0)
                main_dict.append((doc_ids[i],score.item()))

    return main_dict

In [36]:
# Re-rank each query's BM25 candidates with IDCM and evaluate against the qrels.
# The measured time covers tokenization, scoring and evaluation.
start_time = time.time()
query_results_dict={}
for query_id in query_text.keys():
    query = query_text[query_id]
    candidates = to_rank_merged[to_rank_merged["query"] == query_id]
    idcm_input=make_bert_input(idcm_tokenizer,list(candidates["TEXT"]))
    idcm_result_lst= make_main_dict(query,idcm_model,idcm_input,idcm_tokenizer,"IDCM")
    # Label every score with the DOCNO of the candidate actually scored (same row order as
    # `candidates`): the ids make_main_dict attaches by default come from the global `docnos`
    # in collection order. Key by str(query_id) because load_qrels keys queries by string,
    # while the run-file query ids are ints.
    query_results_dict[str(query_id)] = [
        (doc_id, score) for doc_id, (_, score) in zip(candidates["DOCNO"], idcm_result_lst)
    ]

idcm_ranked_results = unrolled_to_ranked_result(query_results_dict)
idcm_metrics = calculate_metrics_plain(idcm_ranked_results,load_qrels(test_config["qrels"]),test_config["binarization_point"])
end_time = time.time()
time_taken = end_time - start_time
# kept separately because the ColBERT cell reuses `time_taken`
idcm_time_taken = time_taken

In [37]:
# IDCM metrics dict from calculate_metrics_plain (last expression, so the notebook displays it)
idcm_metrics

In [None]:

# Re-rank each query's BM25 candidates with ColBERT and evaluate against the qrels.
# The measured time covers tokenization, scoring and evaluation.
start_time = time.time()
query_results_dict={}
for query_id in query_text.keys():
    query = query_text[query_id]
    candidates = to_rank_merged[to_rank_merged["query"] == query_id]
    colbert_input=make_bert_input(colbert_tokenizer,list(candidates["TEXT"]))
    colbert_result_lst= make_main_dict(query,colbert_model,colbert_input,colbert_tokenizer,"COLBERT")
    # Label every score with the DOCNO of the candidate actually scored (same row order as
    # `candidates`): the ids make_main_dict attaches by default come from the global `docnos`
    # in collection order. Key by str(query_id) because load_qrels keys queries by string,
    # while the run-file query ids are ints.
    query_results_dict[str(query_id)] = [
        (doc_id, score) for doc_id, (_, score) in zip(candidates["DOCNO"], colbert_result_lst)
    ]
colbert_ranked_results = unrolled_to_ranked_result(query_results_dict)
colbert_metrics = calculate_metrics_plain(colbert_ranked_results,load_qrels(test_config["qrels"]),test_config["binarization_point"])
end_time = time.time()
time_taken = end_time - start_time
# kept separately so it is not lost when another model's cell reuses `time_taken`
colbert_time_taken = time_taken


In [39]:
# ColBERT metrics dict from calculate_metrics_plain (last expression, so the notebook displays it)
colbert_metrics

In [None]:
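## BERT-DOT and BERT-CAT comparison (TODO: not implemented yet)
# The commented-out snippets below are stale: `texts` is never defined, `make_main_dict` expects the
# query, model, inputs, tokenizer and a model name (and has no BERT-DOT branch), and the second
# snippet scores ColBERT rather than BERT-CAT.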

# bertdot_input=make_bert_input(bertdot_tokenizer,texts)
# bertdot_main_dict= make_main_dict(query_text,bertdot_model,bertdot_input)
# bertdot_ranked_results = unrolled_to_ranked_result(bertdot_main_dict)
# bertdot_metrics = calculate_metrics_plain(bertdot_ranked_results,load_qrels(test_config["qrels"]),test_config["binarization_point"])

# colbert_input=make_bert_input(colbert_tokenizer,texts)
# colbert_main_dict= make_main_dict(query_text,colbert_model,colbert_input)
# colbert_ranked_results = unrolled_to_ranked_result(colbert_main_dict)
# colbert_metrics = calculate_metrics_plain(colbert_ranked_results,load_qrels(test_config["qrels"]),test_config["binarization_point"])
