In [1]:
import os

from eli5_utils import *

#### Load ELI5 dataset

In [2]:
# Download and prepare the ELI5 data, then materialize the train / validation / test splits.
eli5_dbuilder = ELI5NLP(data_dir='eli5')
eli5_dbuilder.download_and_prepare()

eli5_train, eli5_valid, eli5_test = [
    eli5_dbuilder.as_dataset(split=split)
    for split in (nlp.splits.Split.TRAIN, nlp.splits.Split.VALIDATION, nlp.splits.Split.TEST)
]

In [3]:
# Peek at a single training example: a dict with the question fields and an `answers` sub-dict.
example_idx = 1
eli5_train[example_idx]

{'q_id': '2lojul',
 'title': "Why are different tiers (regular < mid < premium) of gas' prices almost always 10 cents different?",
 'selftext': "I've noticed that the difference in price between regular gas and midrange, and between midrange and premium, is almost always 10 cents. This seems to hold true no matter what the price for regular gas. This doesn't seem to make sense, as the difference between $2 and $2.10 and the difference between $4 and $4.10 /gal are proportionally very different. Is this just an arbitrary convention that undermines arguments of a rational basis for gasoline prices?",
 'answers': {'a_id': ['', '', ''],
  'text': ['As someone who uses quality Premium, I wish this was true.',
   "The difference is in how it burns though is what's critical for you as the end consumer. I drive a forced induction car, so air coming into my engine is compressed before it enters the cylinder where it's further compressed by the piston. The Regular, Mid, and Premium gas are rated

#### Load KILT-Wikipedia 100-word passages

In [4]:
# Build the KILT Wikipedia 100-word snippet dataset and expose its train split as `wiki_passages`.
kilt_snippets_dbuilder = KiltSnippets(data_dir='kilt_snippets_100w')
kilt_snippets_dbuilder.download_and_prepare()

train_split = nlp.splits.Split.TRAIN
wiki_passages = kilt_snippets_dbuilder.as_dataset(split=train_split)

In [2]:
# Load the 100-word Wikipedia snippets (config `wikipedia_en_100_0`) through the local
# `wiki_snippets` dataset script, then display one passage record.
# The script location defaults to the author's original checkout but can be overridden with
# the WIKI_SNIPPETS_SCRIPT environment variable instead of editing an absolute path.
wiki_snippets_script = os.environ.get('WIKI_SNIPPETS_SCRIPT',
                                      '/home/yacine/Code/nlp/datasets/wiki_snippets')
wiki_passages = nlp.load_dataset(path=wiki_snippets_script, name="wikipedia_en_100_0")['train']
wiki_passages[123]

{'_id': '{"nlp_id": 26, "wiki_id": "The_Wasps", "sp": 38, "sc": 103, "ep": 44, "ec": 12}',
 'nlp_id': 26,
 'wiki_id': 'The_Wasps',
 'start_paragraph': 38,
 'start_character': 103,
 'end_paragraph': 44,
 'end_character': 12,
 'article_title': 'The Wasps',
 'section_title': 'Some events that influenced The Wasps & Megara: a neighbour and historically a rival to Athens, it is mentioned in line 57 as the reputed origin of comic drama.',
 'passage_text': 'was produced), Old Comedy needs commentators to explain its abstruse references, in the same way that a banquet needs wine waiters. Here is the wine list for The Wasps as supplied by modern scholars.  Places   Megara: a neighbour and historically a rival to Athens, it is mentioned in line 57 as the reputed origin of comic drama. Law Courts: Athens had ten law courts in 422 BC, of which these three are mentioned here by name: The New Court in line 120, The Court at Lykos in line 389  and The Odeion in line 1109. Asclepieia:'}

In [2]:
# Experiment: index the Wiki40b 100-word snippets (config `wiki40b_en_100_0`) in ElasticSearch.
# The passages are kept in `wiki40b_passages` so this cell does NOT overwrite `wiki_passages`,
# which the KILT ElasticSearch index and the KILT dense index further down are built from.
# The dataset-script location can be overridden with the WIKI_SNIPPETS_SCRIPT environment variable.
wiki_snippets_script = os.environ.get('WIKI_SNIPPETS_SCRIPT',
                                      '/home/yacine/Code/nlp/datasets/wiki_snippets')
wiki40b_passages = nlp.load_dataset(path=wiki_snippets_script, name="wiki40b_en_100_0")['train']
es_client = Elasticsearch([{'host': 'localhost', 'port': '9200'}])
make_es_index_snippets(es_client, wiki40b_passages, index_name='english_wiki40b_snippets_100w')

100%|█████████▉| 17553501/17553713 [1:46:45<00:00, 2151.77docs/s]

Indexed 17553713 documents


100%|██████████| 17553713/17553713 [1:47:02<00:00, 2151.77docs/s]

In [4]:
# Drop the Wiki40b snippet index built above; ElasticSearch replies with an acknowledgement dict.
es_client.indices.delete(index='english_wiki40b_snippets_100w')

{'acknowledged': True}

#### Make ElasticSearch Index

In [57]:
# Connect to the local ElasticSearch node and build the KILT snippet index only if it is missing.
es_client = Elasticsearch([{'host': 'localhost', 'port': '9200'}])
kilt_es_index_name = 'english_wiki_kilt_snippets_100w'
kilt_index_missing = not es_client.indices.exists(kilt_es_index_name)
if kilt_index_missing:
    make_es_index_snippets(es_client, wiki_passages, index_name=kilt_es_index_name)

In [60]:
# Sanity-check the sparse (ElasticSearch) retriever on one training question:
# list the top-10 hits (article || section || score), then print the assembled support document.
question = eli5_train[1]['title']
support_doc, hit_lst = query_es_index(
    question, es_client, n_results=10, index_name='english_wiki_kilt_snippets_100w'
)

print('Question:', question, sep='\n')
print('\n -- Sparse retriever fetching information from:\n')
for hit in hit_lst:
    print('{} || {} || {:.2f}'.format(hit['article_title'], hit['section_title'].strip(), hit['score']))
print('\n------------\n')
print('Support document', support_doc, sep='\n')

Question:
Why are different tiers (regular < mid < premium) of gas' prices almost always 10 cents different?

Sparse retriever fetching information from:

2007 Gasoline Rationing Plan in Iran || Gas rationing plan. || 80.05
Filling station || Octane. || 74.96
Winn-Dixie || Brands. || 66.28
Lawrence Lessig || Internet and computer activism. -- Net neutrality. || 65.92
Maryland Electric Deregulation || The perfect storm. || 64.73
Cable television || History in North America. || 63.45
Pay television || Pricing and packaging. || 62.68
International Fairtrade Certification Mark || History. || 62.52
Marcellus natural gas trend || Economic effects. -- Employment. || 62.23
Health insurance marketplace || History. -- Comparable tiers of plans. || 60.69

------------

Support document
<P> international pressure related to its nuclear program.
 Based on the rationing plan, each private car received 120 liters per month at about 10 cents per liter. The price for non-rationed gasoline in November 2

#### Build the dense QA-retriever index

In [5]:
# Load the pre-trained question/passage embedder together with its tokenizer onto GPU 0.
retriever_model_name = "google/bert_uncased_L-8_H-512_A-8"
retriever_checkpoint = "retriever_models/embed_eli5_qa_512_4.pth"
r_tokenizer, r_qa_embedder = make_qa_retriever_model(
    model_name=retriever_model_name,
    from_file=retriever_checkpoint,
    device="cuda:0",
)

In [6]:
# Embed every passage only when the cached representation file is absent, then memory-map it
# read-only as a (num_passages, 128) float16 matrix.
dense_reps_file = 'kilt_passages_reps.dat'
if not os.path.isfile(dense_reps_file):
    make_qa_dense_index(
        r_qa_embedder, r_tokenizer, wiki_passages,
        batch_size=512,
        index_name=dense_reps_file,
        device='cuda:0',
    )

passage_reps = np.memmap(
    dense_reps_file,
    dtype='float16',
    mode='r',
    shape=(wiki_passages.num_rows, 128),
)

In [7]:
# Put the passage embeddings into a flat (brute-force) inner-product faiss index on GPU 0.
# The dimension (128) must match the second axis of `passage_reps`.
# `gpu_res` is a dedicated reference to the GPU resources backing `gpu_index_flat`: later cells
# reuse `res` as a loop variable (`for res in hit_lst`). Depending on the faiss version, the
# Python wrapper may not hold its own reference, so rebinding `res` could free the resources.
# `res` is kept as an alias for backward compatibility.
res = gpu_res = faiss.StandardGpuResources()
index_flat = faiss.IndexFlatIP(128)
gpu_index_flat = faiss.index_cpu_to_gpu(gpu_res, 0, index_flat)
gpu_index_flat.add(passage_reps)

In [8]:
# Sanity-check the dense retriever on the same training question:
# list the top-10 hits (article || section || score), then print the assembled support document.
question = eli5_train[1]['title']
support_doc, hit_lst = query_qa_dense_index(question,
                                            r_qa_embedder, r_tokenizer,
                                            wiki_passages, gpu_index_flat,
                                            n_results=10)

print('Question:')
print(question)
print('\n -- Dense retriever fetching information from:\n')
# The loop variable is `hit`, not `res`: `res` holds the faiss StandardGpuResources that
# `gpu_index_flat` was built with, and rebinding it can drop the last reference to it.
for hit in hit_lst:
    print('{} || {} || {:.2f}'.format(hit['article_title'], hit['section_title'].strip(), hit['score']))
print('\n------------\n')
print('Support document')
print(support_doc)

Question:
Why are different tiers (regular < mid < premium) of gas' prices almost always 10 cents different?

 Dense retriever fetching information from:

Gasoline || Use and pricing. -- United States. || 23.78
Pak'nSave || Fuel discounts. || 23.73
Liquefied natural gas || LNG pricing. -- Price review. || 23.22
Octane rating || Effects. || 23.02
Filling station || Fuel prices. -- North America. || 22.55
Gas meter || Heating value. || 22.53
Natural gas || Energy content, statistics, and pricing. -- Canada. || 22.52
Natural gas || Energy content, statistics, and pricing. -- United States. || 22.29
Filling station || Marketing. -- North America. || 22.25
Gasoline || Use and pricing. -- United States. || 22.23

------------

Support document
<P> rates, including the federal taxes as of October 2018, are found in Pennsylvania (77.1¢/gal), California (73.93¢/gal), and Washington (67.8¢/gal). 
 About 9 percent of all gasoline sold in the U.S. in May 2009 was premium grade, according to the En

#### Pre-compute the retrieved support documents (to save GPU work later)

In [24]:
# Retrieve dense support documents for every ELI5 question ahead of time and cache them as JSON,
# one file per split.
batch_size = 256


def precompute_dense_docs(dataset, out_file, batch_size=256, n_results=10):
    """Run batched dense retrieval over every question title in `dataset` and save the results.

    Args:
        dataset: an ELI5 split; slicing it returns a dict of columns including
            'title' and 'q_id'.
        out_file: path of the JSON file to write.
        batch_size: number of questions sent to the retriever per call.
        n_results: number of passages retrieved per question.

    Returns:
        A list of ``(q_id, support_doc, hits)`` tuples in dataset order. The same
        list is written to `out_file` (tuples are serialized as JSON arrays).
    """
    all_docs = []
    st_time = time()
    for b in range(math.ceil(dataset.num_rows / batch_size)):
        # Slice once and read both columns from it instead of slicing the dataset twice.
        batch = dataset[b * batch_size:(b + 1) * batch_size]
        docs, res_lst = batch_query_qa_dense_index(batch['title'],
                                                   r_qa_embedder, r_tokenizer,
                                                   wiki_passages, gpu_index_flat,
                                                   n_results=n_results)
        all_docs += list(zip(batch['q_id'], docs, res_lst))
        if b % 100 == 0:
            print(b, time() - st_time)  # progress: batch index, elapsed seconds
    # `with` guarantees the file is flushed and closed once the dump completes.
    with open(out_file, 'w') as f:
        json.dump(all_docs, f)
    return all_docs


for eli5_split, split_out_file in [(eli5_train, 'eli5_train_precomputed_dense_docs.json'),
                                   (eli5_valid, 'eli5_valid_precomputed_dense_docs.json'),
                                   (eli5_test, 'eli5_test_precomputed_dense_docs.json')]:
    all_docs = precompute_dense_docs(eli5_split, split_out_file, batch_size=batch_size)

0 0.603712797164917
100 29.4765465259552
200 52.303476095199585
300 73.91007280349731
400 95.56331825256348
500 117.14646506309509


KeyboardInterrupt: 

#### Making retrieval test set

In [85]:
# Build a retrieval evaluation set from the last 5000 ELI5 test questions that have more than
# two answers scoring above 2. Each example stores:
#   - sparse (ElasticSearch) hits for the question (top 100),
#   - dense hits for the question (top 50),
#   - for each of its first three answers, dense hits retrieved from the answer text (top 50).
test_keys = [(i, d['q_id'])
             for i, d in enumerate(eli5_test) if len([s for s in d['answers']['score'] if s > 2]) > 2][-5000:]

retriever_test = [
    {
        'eli5_test_id': i,
        'q_id': k,
        'question': eli5_test[i]['title'],
        'answers': eli5_test[i]['answers']['text'][:3],
    }
    for i, k in test_keys]


def dense_hits_for_answer(answer, n_results=50, span_words=128, stride=64):
    """Dense-retrieve passages for overlapping word windows of `answer`.

    The answer is split on whitespace and cut into windows of `span_words`
    words starting every `stride` words (always at least one window). Each
    window is the query for ``query_qa_dense_index_nn``. The pooled hits are
    sorted by descending 'score', de-duplicated on 'passage_id' (the
    highest-scoring copy is kept) and truncated to `n_results`.
    """
    a_tab = answer.split()
    pooled = []
    for ia in range(max(len(a_tab) // stride, 1)):
        a_span = ' '.join(a_tab[ia * stride:ia * stride + span_words])
        # Query with the answer window itself; querying with the question here would just
        # duplicate the question-based dense retrieval.
        _, span_hits = query_qa_dense_index_nn(a_span,
                                               r_qa_embedder, r_tokenizer,
                                               wiki_passages, gpu_index_flat,
                                               n_results=n_results)
        pooled += span_hits
    # Linear-time de-duplication over the score-sorted hits.
    seen_ids = set()
    unique_hits = []
    for hit in sorted(pooled, key=lambda x: x['score'], reverse=True):
        if hit['passage_id'] not in seen_ids:
            seen_ids.add(hit['passage_id'])
            unique_hits.append(hit)
    return unique_hits[:n_results]


st_time = time()
for ct, r_example in enumerate(retriever_test):
    _, es_q_hits = query_es_index(
        r_example['question'], es_client,
        n_results=100,
        index_name='english_wiki_kilt_snippets_100w'
    )
    r_example['question_suggested_passages_sparse'] = es_q_hits[:]
    _, ds_q_hits = query_qa_dense_index(
        r_example['question'],
        r_qa_embedder, r_tokenizer,
        wiki_passages, gpu_index_flat,
        n_results=50
    )
    r_example['question_suggested_passages'] = ds_q_hits[:]
    r_example['answer_suggested_passages'] = [dense_hits_for_answer(answer)
                                              for answer in r_example['answers']]
    if ct % 100 == 0:
        print(ct, time() - st_time)  # progress: examples done, elapsed seconds

# NOTE(review): the output file name contains a space ("with dense"); kept unchanged so
# existing consumers of this file still find it.
with open('eli5_retriever_test_set_with dense_suggestions.json', 'w') as f:
    json.dump(retriever_test, f)