In [None]:
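# One-time change required so that the DensePhrases code can run from a Jupyter notebook.
# At L227 in DPhrases/options.py, replace
#     opt = self.parser.parse_args()
# with
#     opt, unknown = self.parser.parse_known_args()
# (parse_known_args tolerates the extra command-line arguments that Jupyter passes to the kernel.)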

In [45]:
# Set environment variables
# Each %env magic below sets one variable in the kernel's process environment.
# All four paths are relative, so they resolve against the notebook's working directory.
# NOTE(review): presumably these are read by the DensePhrases code imported below — confirm.

%env BASE_DIR=../
%env DATA_DIR=../densephrases-data
%env SAVE_DIR=../outputs
%env CACHE_DIR=../cache

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

# Star import: later cells use names that are not defined anywhere in this notebook
# (e.g. load_densephrase_module, read_queries, tqdm, json).
# NOTE(review): presumably they all come from this import — confirm against integrate.py.
from integrate import *

env: BASE_DIR=../
env: DATA_DIR=../densephrases-data
env: SAVE_DIR=../outputs
env: CACHE_DIR=../cache
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# EDIT THIS: inference parameters for the run

params = dict(
    top_k=5,
    use_large_index=True,
    strip_qmark=False,
    strip_qword1=True,
    strip_qword2=False,
    strip_qword_mode="all",    # one of: first / all
    prepend_hop_phrase=False,
    retrieval_unit="phrase",   # first hop only; one of: phrase / sentence / paragraph
    single_hop=True,
    mult_path_scores=True,
)

# Number of questions handed to the model per batched inference call
batch_size = 100

In [3]:
# Load the DensePhrases module.
# When the large index is requested, every '_small' occurrence is removed from
# idx_name; otherwise the name is passed through unchanged.
print("Loading DensePhrases module...")
model = load_densephrase_module(
    load_dir=load_dir,
    dump_dir=dump_dir,
    index_name=idx_name.replace('_small', '' if params["use_large_index"] else '_small'),
    device=device,
)

Loading DensePhrases module...
This could take up to 15 mins depending on the file reading speed of HDD/SSD
Loading DensePhrases Completed!


In [35]:
# Load the queries from data_path and pull out parallel question / answer lists
queries = read_queries(data_path)  #, bridge_only=False, filter_yes_no=True)
questions = get_key(queries, 'question')
# Each 'answer' entry is indexable; only its first element is kept (flattening)
answers = [candidates[0] for candidates in get_key(queries, 'answer')]

# Translate the user-facing parameters into the arguments the inference calls expect
if params["prepend_hop_phrase"]:
    method = 'pre'
else:
    method = 'post'
top_k = params["top_k"]
ret_unit1 = params["retrieval_unit"]
# The question-mark setting is shared by both hops
strip_ques1 = strip_ques2 = params["strip_qmark"]
# The question-word setting is configured per hop
strip_prompt1 = params["strip_qword1"]
strip_prompt2 = params["strip_qword2"]
strip_prompt_mode = params["strip_qword_mode"]
single_hop = params["single_hop"]
mult_path_scores = params["mult_path_scores"]

Loading data from /gypsum/scratch1/dagarwal/multihop_dense_retrieval/data/hotpot/hotpot_qas_val.json
Read 5918 questions in 7405 total questions.


In [10]:
# Set DEBUG to True to run on the first 20 questions only
DEBUG = False
ques = questions[:20] if DEBUG else questions

# Run batched inference: walk the questions in windows of batch_size and
# collect the per-batch results into one flat list.
results = []
print("Running batched multi-hop inference...")
for offset in tqdm(range(0, len(ques), batch_size)):
    # The answers slice uses the same window as the questions slice so the two stay aligned
    batch_results = run_batch_inference(
        model, ques[offset:offset + batch_size], strip_ques1, strip_prompt1,
        strip_ques2, strip_prompt2, ret_unit1, ret_unit2, ques_terms, method,
        strip_prompt_mode, answers=answers[offset:offset + batch_size], write=False,
        top_k=top_k, silent=True, single_hop=single_hop, mult_scores=mult_path_scores,
    )
    results.extend(batch_results)

Running batched multi-hop inference...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [16:10<00:00, 13.87s/it]


In [57]:
# Write predictions to disk
def write_preds(results, params=None, prefix=None, single_hop=None):
    """Write predictions (and optionally run metadata) to timestamped JSON files.

    Files are created in the current working directory. The predictions go to
    ``{name}_{run_id}.json`` and, when ``params`` is given, the metadata goes
    to ``{name}_{run_id}_meta.json``. ``run_id`` is the current Unix time in
    whole seconds.

    Args:
        results: JSON-serializable predictions to store.
        params: Optional JSON-serializable run parameters; when not None they
            are written to the companion ``_meta.json`` file.
        prefix: Optional file-name prefix. When given it is used as ``name``
            and ``single_hop`` is ignored.
        single_hop: Only consulted when ``prefix`` is None. A truthy value
            selects the name ``singlehop``, a falsy non-None value selects
            ``predictions``, and None selects ``preds``.
    """
    # Local imports keep the function independent of whatever the notebook's
    # star import happens to expose.
    import json
    import time

    # Whole-second UTC timestamp
    run_id = int(time.time())

    if prefix is not None:
        name = prefix
    elif single_hop is None:
        name = "preds"
    else:
        name = 'singlehop' if single_hop else 'predictions'
    stem = f'{name}_{run_id}'
    out_file = f'{stem}.json'

    with open(out_file, 'w') as fp:
        json.dump(results, fp, indent=4)
    print(f"Predictions saved at {out_file}")

    if params is not None:
        # Build the name from the stem rather than str.replace('.json', ...),
        # which would also rewrite a '.json' that appears inside the prefix.
        meta_out_file = f'{stem}_meta.json'
        with open(meta_out_file, 'w') as fp:
            json.dump(params, fp, indent=4)
        print(f"Run metadata saved at {meta_out_file}")

In [74]:
# Construct oracle data

import unicodedata
def normalize(text):
    """Return ``text`` converted to Unicode normalization form NFD."""
    decomposed = unicodedata.normalize('NFD', text)
    return decomposed

# Build four families of oracle queries from the first-hop dev file.
# For each bridge record, the normalized question is extended with:
#   * oracle_sent_questions:       one query per h['answers'] entry
#   * oracle_title_questions:      one query per h['titles'] entry
#   * oracle_title_sent_questions: one query per (title, answers-entry) pair
#   * oracle_all_sent_questions:   a single query with every answers-entry appended
# NOTE(review): later cells index these lists by position in `questions`, so they
# presumably must match the filtering and order of read_queries — confirm.

# JSON is UTF-8 by specification; do not depend on the platform default encoding.
with open('../densephrases-data/hotpotqa/hotpot_dev_firsthop.json', 'r', encoding='utf-8') as handle:
    hotpot_dev = json.load(handle)

oracle_sent_questions = []
oracle_title_questions = []
oracle_title_sent_questions = []
oracle_all_sent_questions = []
for h in hotpot_dev['data']:
    # Only bridge questions get oracle queries; skip the rest before doing any work.
    if h['type'] != 'bridge':
        continue
    # Normalize the question once instead of once per generated query.
    q_norm = normalize(h['question'])
    q_sent_set = []
    q_title_set = []
    q_title_sent_set = []
    q_all_sent_set = q_norm
    for i, sent in enumerate(h['answers']):
        sent_norm = normalize(sent)
        # Index rather than zip so a too-short titles list still raises, as before.
        title_norm = normalize(h['titles'][i])
        q_sent_set.append(q_norm + " " + sent_norm)
        q_title_set.append(q_norm + " " + title_norm)
        q_title_sent_set.append(q_norm + " " + title_norm + " " + sent_norm)
        q_all_sent_set += " " + sent_norm
    # De-duplicate while keeping first-occurrence order. list(set(...)) would make
    # the query order vary between kernel runs because string hashing is randomized.
    oracle_sent_questions.append(list(dict.fromkeys(q_sent_set)))
    oracle_title_questions.append(list(dict.fromkeys(q_title_set)))
    oracle_title_sent_questions.append(list(dict.fromkeys(q_title_sent_set)))
    oracle_all_sent_questions.append(q_all_sent_set)

In [58]:
# Oracle predictions - sentence-version
# One oracle inference call per question, using the queries in oracle_sent_questions.
sent_results = [
    run_oracle_inference(model, questions[i], oracle_sent_questions[i], answers[i])
    for i in tqdm(range(len(questions)))
]

write_preds(sent_results, prefix='oracle_sent_preds')

Predictions saved at oracle_sent_preds_1648199883.json


In [71]:
# Oracle predictions - title-version
# One oracle inference call per question, using the queries in oracle_title_questions.
title_results = [
    run_oracle_inference(model, questions[i], oracle_title_questions[i], answers[i])
    for i in tqdm(range(len(questions)))
]
write_preds(title_results, prefix='oracle_title_preds')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5918/5918 [14:43<00:00,  6.70it/s]


Predictions saved at oracle_title_preds_1648201162.json


In [73]:
# Oracle predictions - title+sent-version
# One oracle inference call per question, using the queries in oracle_title_sent_questions.
title_sent_results = [
    run_oracle_inference(model, questions[i], oracle_title_sent_questions[i], answers[i])
    for i in tqdm(range(len(questions)))
]
write_preds(title_sent_results, prefix='oracle_title_sent_preds')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5918/5918 [15:59<00:00,  6.17it/s]


Predictions saved at oracle_title_sent_preds_1648202141.json


In [78]:
# Oracle predictions - all_sents-version
# Each question has a single combined oracle query here, so it is wrapped in a
# one-element list before being passed to the oracle inference call.
all_sent_results = [
    run_oracle_inference(model, questions[i], [oracle_all_sent_questions[i]], answers[i])
    for i in tqdm(range(len(questions)))
]
write_preds(all_sent_results, prefix='oracle_all_sent_preds')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5918/5918 [14:57<00:00,  6.59it/s]


Predictions saved at oracle_all_sent_preds_1648226462.json
