In [14]:
from __future__ import absolute_import, division, print_function

import argparse
import logging
import os
import sys
import random
from tqdm import tqdm_notebook as tqdm
from tqdm import tnrange as trange

import numpy as np

import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from torch.nn import CrossEntropyLoss, MSELoss


if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

from pytorch_transformers import BertForNextSentencePrediction, BertTokenizer

logger = logging.getLogger(__name__)


sys.path.append("../scripts/run_classifier_dataset_utils.py")


In [7]:
task_name = "msmarco"
do_train = False
do_eval = True
do_lower_case = True
data_dir = "/ssd2/arthur/TREC2019/data/"
bert_model = "bert-base-uncased"
max_seq_length = 512 
train_batch_size = 32
learning_rate = 2e-5 
num_train_epochs = 3.0 
output_dir = os.path.join(data_dir, "models")
overwrite_output_dir = False
eval_batch_size = 128

local_rank = -1

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

In [9]:
model = BertForNextSentencePrediction.from_pretrained(output_dir)
model = torch.nn.DataParallel(model)
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=do_lower_case)
model.to(device)
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [19]:
eval_examples[0]

<run_classifier_dataset_utils.InputFeatures at 0x7f66a2071390>

In [26]:
%reload_ext autoreload
%autoreload 2

from run_classifier_dataset_utils import load_dataset
eval_dataloader, eval_examples = load_dataset(task_name, bert_model, max_seq_length, data_dir, tokenizer, eval_batch_size, eval=True, return_examples=True)
assert eval_examples[0].guid == "dev-174249-D3126537"

Reading input tsv:   3%|▎         | 58091/2122814 [00:08<04:47, 7187.93it/s]
creating examples...: 58091it [00:00, 328399.81it/s]


In [27]:
eval_examples[0].guid

'dev-174249-D3126537'

In [18]:
model.eval()
eval_loss = 0
nb_eval_steps = 0
preds = []
out_label_ids = None
scores = []
classes = []

softmax = torch.nn.Softmax(dim=1)
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    label_ids = label_ids.to(device)
    with torch.no_grad():
        outputs = model(input_ids, token_type_ids=segment_ids, next_sentence_label=label_ids)
        predictions = outputs[1]
        eval_loss += outputs[0]
        
        scores += list(predictions[:, 1].cpu().detach().numpy())
        
        classes += list(torch.argmax(predictions, dim=1).cpu().numpy())

        nb_eval_steps+=1

HBox(children=(IntProgress(value=0, description='Evaluating', max=391, style=ProgressStyle(description_width='…






In [8]:
from IPython.core.debugger import set_trace

from collections import defaultdict
#load bm25 scores. Fnal score is a combination THIS IS BUGGY
bm25_scores = {}
bm25_run_file = "/ssd2/arthur/terrier-core/var/results/run.msmarco_docs.bm25.res"
guids = []
last_topic = None
normalized_scores = []
ordered_topics = []
scores_per_topic = defaultdict(lambda:[])


with open(bm25_run_file, 'r') as inf:
    for counter, line in enumerate(inf):
        [topic_id, _, doc_id, _, score, _] = line.split()
        if topic_id not in ordered_topics:
            ordered_topics.append(topic_id)
        scores_per_topic[topic_id].append((doc_id, score))
#normalize
for _id in scores_per_topic:
    scores = np.asarray([float(x[1]) for x in scores_per_topic[_id]])
    normalized_scores = (scores - np.min(scores))/np.ptp(scores)
    for (did, _), score in zip(scores_per_topic[_id], normalized_scores):
        guid = "{}-{}".format(_id, did)
        bm25_scores[guid] = score

In [20]:
## from IPython.core.debugger import set_trace

import subprocess

runs_format = "{} Q0 {} {} {} BERT_BM25\n" #topic_id, doc_id, ranking, score

n_alphas = 50
for a in range(0, n_alphas):
    alpha = a/n_alphas
    beta = 1-alpha

    run_file = os.path.join("/ssd2/arthur/terrier-core/var/results/bert-{}.res".format(alpha))

    topic_results = []
    last_topic = eval_examples[0].guid.split("-")[1]
    with open(run_file, 'w') as outf, open(bm25_run_file) as inf:
        for counter, (example, score) in enumerate(zip(eval_examples, scores)):
            [_, topic_id, doc_id] = example.guid.split("-")
            if topic_id != last_topic:
                last_topic = topic_id
                topic_results.sort(key = lambda x:x['score'], reverse=True)
                for rank, topic in enumerate(topic_results):
                    outf.write(runs_format.format(topic['topic_id'], topic['doc_id'], rank, topic['score']))
                topic_results = []
            topic_results.append({'topic_id': topic_id, 'doc_id': doc_id, 'score': alpha*score+beta*bm25_scores[f"{topic_id}-{doc_id}"]})
            last_topic = topic_id
        for rank, topic in enumerate(topic_results):
            outf.write(runs_format.format(topic['topic_id'], topic['doc_id'], rank, topic['score']))

# eval script:
cmd = "/ssd2/arthur/terrier-core/bin/terrier batchevaluate -f -q {}".format(os.path.join(data_dir, "msmarco-docdev-qrels.tsv"))
output = subprocess.run(cmd.split(), capture_output=True)
lines = output.stdout.decode("utf-8").split("\n")[3:-1]
max_score = 0.0
for i, j in list(zip(lines[:-1], lines[1:]))[::2]:
    alpha = i.split("-")[-1].split(".res")[0]
    score = float(j.split(":")[-1])
    print(alpha, score)
    if score > max_score:
        max_score = score
        best_alpha = alpha
print(best_alpha, max_score)

0.0 0.2735
0.02 0.2926
0.04 0.3067
0.06 0.3073
0.08 0.3128
0.1 0.3327
0.12 0.3426
0.14 0.3357
0.16 0.3561
0.18 0.3592
0.2 0.3581
0.22 0.3592
0.24 0.354
0.26 0.3511
0.28 0.3493
0.3 0.3496
0.32 0.3475
0.34 0.3484
0.36 0.3438
0.38 0.3371
0.4 0.3353
0.42 0.3335
0.44 0.3258
0.46 0.3023
0.48 0.2934
0.5 0.2816
0.52 0.2664
0.54 0.2586
0.56 0.2561
0.58 0.2549
0.6 0.2402
0.62 0.2363
0.64 0.2144
0.66 0.2123
0.68 0.207
0.7 0.204
0.72 0.2022
0.74 0.2001
0.76 0.1981
0.78 0.1948
0.8 0.184
0.82 0.1779
0.84 0.1654
0.86 0.1578
0.88 0.1574
0.9 0.1535
0.92 0.1509
0.94 0.1461
0.96 0.145
0.98 0.1442
0.18 0.3592


In [None]:
lines[0]

In [None]:
26.01846062624611

In [None]:
example.guid

In [None]:
[x for x in topic_results if x['doc_id']=="D3240836"]