## Install libraries

In [26]:
# Install the third-party packages this notebook relies on. Versions are not
# pinned, so each run installs whatever release is current.
# NOTE(review): livelossplot and nvidia-ml-py3 are not imported in any of the
# visible cells - confirm they are still required.
!pip install pytorch-pretrained-bert
!pip install livelossplot
!pip install nvidia-ml-py3
!pip install unidecode

[33mYou are using pip version 10.0.1, however version 20.0.2 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m
[33mYou are using pip version 10.0.1, however version 20.0.2 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m
[33mYou are using pip version 10.0.1, however version 20.0.2 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m
[33mYou are using pip version 10.0.1, however version 20.0.2 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


## Import libraries

In [2]:
import pandas as pd
import numpy as np
import os
import json
import unidecode
import re
import torch

from tqdm.auto import tqdm 
from tqdm import tqdm_notebook

from pytorch_pretrained_bert import BertTokenizer, BertModel
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME, BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                  BertTokenizer,
                                                  whitespace_tokenize)

## Helper Functions

In [3]:
def get_lower_ids(session_df, query_id):
    """Return the IDs of the queries that come before `query_id` in its session.

    Query IDs are strings of the form '<session>_<position>'. Every entry of
    session_df['query_id'] whose position is strictly smaller than the position
    of `query_id` is returned, rebuilt as '<session>_<position>' with the
    session number taken from `query_id`. The order of `session_df` is kept.
    """
    id_parts = query_id.split('_')
    session_id = int(id_parts[0])
    current_id = int(id_parts[1])

    lower_ids = []
    for other_id in session_df['query_id'].tolist():
        position = int(other_id.split('_')[1])
        if position < current_id:
            lower_ids.append(str(session_id) + '_' + str(position))
    return lower_ids

In [4]:
def remove_non_alphanumeric(text):
    """Transliterate `text` with unidecode and blank out non-alphanumerics.

    The input is first converted with str(). After transliteration every
    character outside a-z, A-Z and 0-9 is replaced by a single space (the
    characters are substituted, not removed).
    """
    transliterated = unidecode.unidecode(str(text))
    return re.sub(r'[^a-zA-Z0-9]', ' ', transliterated)

In [5]:
def get_segment_ids_from_index_tokens(indexed_tokens, sep_token_id=102):
    """Build the list of segment ids for a sequence of token ids.

    Tokens before the first occurrence of `sep_token_id` get segment id 0.
    That first separator and every token after it get segment id 1.

    Args:
        indexed_tokens: iterable of token ids.
        sep_token_id: id of the separator token. Defaults to 102, presumably
            the [SEP] id of the 'bert-base-uncased' vocabulary used in this
            notebook - pass another value for a different vocabulary.

    Returns:
        A list of 0/1 segment ids with one entry per input token.
    """
    segment_ids = []
    seen_sep = False
    for token in indexed_tokens:
        # Once the separator has been met the flag stays set for the rest.
        seen_sep = seen_sep or token == sep_token_id
        segment_ids.append(1 if seen_sep else 0)
    return segment_ids

In [6]:
def run_bert(data, max_len=512):
    """Run the global `bertmodel` on every row of `data`, one row at a time.

    Args:
        data: DataFrame with an 'indexed_tokens' column (list of token ids per
            row) and a 'segment_ids' column (list of segment ids per row).
        max_len: maximum number of tokens fed to the model for one row.
            Defaults to 512.

    Returns:
        The same `data` object, with a new 'pooled_output' column holding the
        raw model output for each row. `data` is modified in place.
    """
    # Use the device the model already lives on instead of assuming CUDA.
    device = next(bertmodel.parameters()).device

    activations = []
    # Iterating the two columns directly avoids a positional row lookup
    # (data.iloc[i]) for every single record.
    rows = zip(data['indexed_tokens'], data['segment_ids'])
    for tokens, segment_ids in tqdm(rows, total=len(data)):
        # Make sure the input fits: keep the first token plus the trailing
        # max_len - 1 tokens, so the result is exactly max_len long.
        # NOTE(review): this drops tokens from the front of the sequence,
        # which is where the query sits in "[CLS] query [SEP] passage [SEP]".
        overflow = len(tokens) - max_len
        if overflow > 0:
            tokens = [tokens[0]] + tokens[overflow + 1:]
            segment_ids = [segment_ids[0]] + segment_ids[overflow + 1:]

        # Batch of size one, created directly on the model's device.
        tokens_tensor = torch.tensor([tokens], device=device)
        segments_tensors = torch.tensor([segment_ids], device=device)

        with torch.no_grad():
            # NOTE(review): outputs stay on the model's device; with many rows
            # this may hold a lot of device memory - confirm this is intended.
            activations.append(bertmodel(tokens_tensor, segments_tensors))

    data['pooled_output'] = activations
    return data

## Load Data

Change the following filenames to match the query file, the Anserini run file, and the output file you want to use.

In [7]:
# Query file (tab-separated: query id, query text), read from msmarco_dir.
query_subset_filename = 'queries.dev.small.tsv'
# Anserini run file (tab-separated: query id, passage id, BM25 rank), read from anserini_output_dir.
anserini_output_filename = 'run.dev.small.tsv'
# Name of the final re-ranked file written into output_dir.
output_filename = 'bert_run_development_top1000.tsv'

Do not change the following directory paths.

In [8]:
# Directory holding the fine-tuned BERT weights.
models_dir = "../data/models/"
# Directory holding the MS MARCO collection and query files.
msmarco_dir = "../data/msmarco_files/"
# Directory holding the Anserini (BM25) run files.
anserini_output_dir = "../data/anserini_output/"
# Directory that receives the files written by this notebook.
output_dir = "../data/output/"

In [9]:
# MS MARCO passage collection: a header-less, tab-separated file whose two
# columns are named passage_id and passage after loading.
msmarco_collection = pd.read_csv(
    msmarco_dir + 'collection.tsv',
    sep='\t',
    header=None,
    encoding='utf-8',
)
msmarco_collection.columns = ['passage_id', 'passage']

In [10]:
# Queries to re-rank: a header-less, tab-separated file whose two columns are
# named query_id and query after loading.
query_subset = pd.read_csv(
    msmarco_dir + query_subset_filename,
    sep='\t',
    header=None,
    encoding='utf-8',
)
query_subset.columns = ['query_id', 'query']

In [11]:
# Anserini run: a header-less, tab-separated file whose three columns are
# named query_id, passage_id and bm25_rank after loading.
query_anserini_output = pd.read_csv(
    anserini_output_dir + anserini_output_filename,
    sep='\t',
    header=None,
    encoding='utf-8',
)
query_anserini_output.columns = ['query_id', 'passage_id', 'bm25_rank']

In [24]:
# One row per distinct query id found in the Anserini run, in sorted order.
unique_query_ids = np.unique(query_anserini_output['query_id'].tolist())
top1000_query_ids = pd.DataFrame({'query_id': list(unique_query_ids)})

## Make BERT DataFrame

In [27]:
# Register progress_apply on pandas objects (one registration is enough).
tqdm.pandas()

# Start from the distinct query ids, attach every BM25 candidate of each query,
# then pull in the query text and the passage text. merge() returns a new
# frame, so top1000_query_ids itself is left untouched.
bert_df = (
    top1000_query_ids
    .merge(query_anserini_output, how='left', on=['query_id'])
    .merge(query_subset, how='left', on=['query_id'])
    .merge(msmarco_collection, how='left', on=['passage_id'])
)

# A left join leaves NaN where an id has no match; calling .lower() on such a
# value would fail with an unhelpful AttributeError, so fail early and clearly.
missing_text = bert_df[['query', 'passage']].isna().sum()
if missing_text.any():
    raise ValueError(
        'Rows without text after the left joins: '
        '{} missing queries, {} missing passages'.format(
            int(missing_text['query']), int(missing_text['passage'])
        )
    )

# Lower-case and normalise both texts, then build the model input string in
# the "[CLS] query [SEP] passage [SEP]" layout.
bert_df['query'] = bert_df['query'].progress_apply(lambda x: remove_non_alphanumeric(x.lower()))
bert_df['passage'] = bert_df['passage'].progress_apply(lambda x: remove_non_alphanumeric(x.lower()))
bert_df['input_text'] = "[CLS] " + bert_df['query'] +" [SEP] " + bert_df['passage'] + " [SEP]"

HBox(children=(FloatProgress(value=0.0, max=6974598.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=6974598.0), HTML(value='')))




## Load BERT Model

In [28]:
# Build a two-label sequence classifier from the 'bert-base-uncased'
# checkpoint, then overwrite its weights with the fine-tuned state dict
# stored in models_dir.
bertmodel = BertForSequenceClassification.from_pretrained('bert-base-uncased', 2)
fine_tuned_state = torch.load(models_dir + 'fine_tuned_bert_base_uncased')
bertmodel.load_state_dict(fine_tuned_state)

# Switch to evaluation mode and move the model to the GPU; the model is the
# value of the last expression, so the cell still displays its structure.
bertmodel.eval().to('cuda')

100%|██████████| 407873900/407873900 [00:18<00:00, 21902723.48B/s]


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
   

In [29]:
# Register tqdm's progress_apply on pandas objects.
tqdm.pandas()
# Tokenizer for the same 'bert-base-uncased' vocabulary the model was initialised from.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

100%|██████████| 231508/231508 [00:00<00:00, 1140419.72B/s]


## Run BERT

In [None]:
# Tokenise each prepared input string and map the tokens to vocabulary ids.
# Only 'input_text' is needed, so apply over that single column: a row-wise
# DataFrame.apply(axis=1) builds a Series for every record and may try to
# expand list results into columns.
bert_df['indexed_tokens'] = bert_df['input_text'].progress_apply(
    lambda text: tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
)

HBox(children=(FloatProgress(value=0.0, max=6974598.0), HTML(value='')))




In [31]:
# Derive the segment ids from the token ids of each row. Only
# 'indexed_tokens' is needed, so apply over that single column instead of
# row-wise over the whole frame.
bert_df['segment_ids'] = bert_df['indexed_tokens'].progress_apply(
    get_segment_ids_from_index_tokens
)

HBox(children=(FloatProgress(value=0.0, max=6974598.0), HTML(value='')))




In [33]:
# Persist the prepared frame (all columns, tab-separated, no header, no index)
# to output_dir.
bert_df.to_csv(output_dir + "bert_msmarco_leaderboard_df",sep='\t',header=False,index=False)

In [32]:
# Score every (query, passage) row. run_bert adds the 'pooled_output' column
# to the frame it receives and returns that same frame, so output_df and
# bert_df refer to the same object afterwards.
output_df = run_bert(bert_df)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  app.launch_new_instance()


HBox(children=(FloatProgress(value=0.0, max=6974598.0), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# Write the frame as it stands right after the BERT run (tab-separated, no
# header, no index).
# NOTE(review): at this point the frame still contains input_text,
# indexed_tokens, segment_ids and the raw pooled_output values, and it is
# written to the same path as the final ranking, which a later cell
# overwrites - confirm this early write is still wanted.
output_df.to_csv(output_dir + output_filename,sep="\t", header=False,index=False)

In [None]:
# Turn each stored model output into a plain Python float: element [0][1] is
# the second of the two label scores for the single sequence in the batch
# (presumably the "relevant" class - confirm against the fine-tuning labels).
# Only 'pooled_output' is needed, so apply over that column rather than
# row-wise over the whole frame.
output_df['score_bert'] = output_df['pooled_output'].progress_apply(
    lambda model_output: model_output.data[0][1].item()
)
# Drop the bulky intermediate columns that are not part of the final ranking.
output_df = output_df.drop(columns=['input_text', 'indexed_tokens', 'segment_ids', 'pooled_output'])

In [None]:
# Rank the passages of each query by descending BERT score. Dense ranking
# gives tied scores the same rank and leaves no gaps; the ranks are stored as
# integers.
dense_rank = output_df.groupby("query_id")["score_bert"].rank(method='dense', ascending=False)
output_df["bert_rank"] = dense_rank.astype(int)

In [None]:
# Write the final ranking to output_dir (tab-separated, no header, no index).
# The remaining columns are query_id, passage_id, bm25_rank, query, passage,
# score_bert and bert_rank.
output_df.to_csv(output_dir + output_filename,sep="\t", header=False,index=False)