# Long Form Question Answering with ELI5 and Wikipedia  

---  

### Table of Contents  

1. [Introduction](#intro)  
    a. [Preliminaries](#prelims)
2. [Task and Data Description](#task_description)  
    a. [Note on Data and Biases](#reddit_biases)
3. [Retrieving Support Documents](#retrieval)  
    a. [Sparse Retrieval with ElasticSearch](#elasticsearch)  
    b. [Training a Dense Retriever with ELI5 and in-batch Negatives](#dense_train)  
    c. [Using a Trained Dense Retriever](#dense_use)  
    d. [Retriever Evaluation](#dense_eval)  
4. [Answer Generation Model](#generation)  
    a. [Conditional Generation with Seq2seq Models](#seq2seq_presentation)  
    b. [Fine-Tuning Seq2seq Models](#seq2seq_train)  
5. [Conclusion](#conclusion)  


---

<img src="images/choco_bis.svg" width="900" align="center"/>  


## Introduction
<a id='intro'></a>

Imagine that you are taken with a sudden desire to understand **how the fruit of a tropical tree gets transformed into chocolate bars**, or want to understand **the role of fever in the human body's immune response**: how would you go about finding that information?

If your specific question has already been asked and answered clearly and succinctly on one of the many question answering platforms available on the Internet (such as [**Quora**](https://www.quora.com/How-is-chocolate-made), [**Reddit**](https://www.reddit.com/user/ex_5_libris/comments/9c8gb1/chocolate_how_chocolate_is_made/), or [**Yahoo Answers**](https://answers.yahoo.com/question/index?qid=20070615082202AArsYN1)), you're in luck: modern search engines will probably take you to that pre-existing answer reliably in a matter of a few clicks.

Otherwise, the process will be a little more involved. You will likely have to collect relevant information from a variety of sources, figure out how these pieces of knowledge fit together in relation to your query, and synthesize a narrative that answers your initial question.

Now, wouldn't it be great if your computer could do all of that for you: **gather** the right sources, **synthesize** the information, and **write up** an easy-to-read summary of the relevant points? Such a system isn't quite available yet, at least not one that can provide *reliable* information in its summary. However, a number of recent advances in natural language understanding and generation have made working toward solving this problem much easier! These advances include progress in the pre-training (e.g. [BART](https://arxiv.org/abs/1910.13461), [T5](https://arxiv.org/abs/1910.10683)) and evaluation (e.g. for [factuality](https://arxiv.org/abs/2004.04228)) of sequence-to-sequence models for conditional text generation, new ways to use language understanding models to find information in Wikipedia (e.g. [REALM](https://kentonl.com/pub/gltpc.2020.pdf), [DPR](https://arxiv.org/abs/2004.04906)), and new [training datasets](https://arxiv.org/abs/1907.09190).

**In this notebook,** we show how we can take advantage of some of these recent works to train a **long form question answering** system which takes in a question, fetches 10 relevant passages from a [Wikipedia snapshot](https://www.aclweb.org/anthology/2020.lrec-1.297/), and writes a multi-sentence answer based on the question and retrieved passages. Follow along to learn about the steps involved and read some background on the state of the art for some related tasks, or go straight to the:  
## [**Live Demo!**](http://35.226.96.115:8080/)  
(And don't forget to scroll down on the left sidebar to show all of the generation options!)

### Preliminaries  
<a id='prelims'></a>

The implementation presented here relies on the [HuggingFace](https://huggingface.co/) [🤗transformers](https://github.com/huggingface/transformers) and [🤗nlp](https://github.com/huggingface/nlp) libraries. Wikipedia indexing relies on [ElasticSearch](https://www.elastic.co/elasticsearch) with its [Python bindings](https://github.com/elastic/elasticsearch-py) for the sparse version, and on [faiss](https://github.com/facebookresearch/faiss/) for the dense version. You can get all of these by running:
> pip install elasticsearch  
> pip install faiss_gpu  
> pip install nlp  
> pip install transformers  
>  
> wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.7.1-linux-x86_64.tar.gz  
> tar -xzvf elasticsearch-7.7.1-linux-x86_64.tar.gz  

The training relies on two datasets: [ELI5](https://arxiv.org/abs/1907.09190), a processed version of the [r/explainlikeimfive](https://www.reddit.com/r/explainlikeimfive/) subreddit, and the [Wiki40b](https://www.aclweb.org/anthology/2020.lrec-1.297/) Wikipedia snapshot.

Downloading ELI5 can take up to 72 hours, since we need to filter through 8 years of Reddit dumps, so we suggest that you start with that (you will need a good download speed and about 10GB of disk space):

In [1]:
import nlp
eli5 = nlp.load_dataset('explainlikeimfive', name='LFQA_reddit', experimental=True)

This notebook is meant to be run from the `transformers/examples/eli5` folder of the [🤗transformers](https://github.com/huggingface/transformers) repository, as all of the useful methods called here are compiled in the [eli5_utils.py](https://github.com/yjernite/transformers/blob/eli5_examples/examples/eli5/eli5_utils.py) script located there:

In [None]:
from eli5_utils import *

## Task and Data Description
<a id='task_description'></a>

Let's recap: we are interested in the task of Long Form Question Answering. As in other Question Answering tasks, the model is presented with a question and is required to generate a natural language answer. Whereas most QA datasets consist mainly of **factoid** questions, where the answer (such as a date or the name of a single entity) can be expressed in a few words or a single sentence, Long Form QA focuses on questions that call for an **explanation** consisting of a few sentences or a few paragraphs.

In order to teach a model to answer such questions, we use questions and answers written by Reddit users. Note that the `nlp.load_dataset` command above actually downloaded questions and their associated answers from the [r/explainlikeimfive](https://www.reddit.com/r/explainlikeimfive/), [r/askscience](https://www.reddit.com/r/askscience/), and [r/AskHistorians](https://www.reddit.com/r/AskHistorians/) subreddits. We focus here on the **ELI5/explainlikeimfive** part to train the system, as the examples there tend to be a little simpler.  

Let's look at one item from the test set:

In [2]:
eli5['test_eli5'][12345]

{'q_id': '8houtx',
 'title': 'Why does water heated to room temperature feel colder than the air around it?',
 'selftext': '',
 'document': '',
 'subreddit': 'explainlikeimfive',
 'answers': {'a_id': ['dylcnfk', 'dylcj49'],
  'text': ["Water transfers heat more efficiently than air. When something feels cold it's because heat is being transferred from your skin to whatever you're touching. Since water absorbs the heat more readily than air, it feels colder.",
   "Air isn't as good at transferring heat compared to something like water or steel (sit on a room temperature steel bench vs. a room temperature wooden bench, and the steel one will feel more cold).\n\nWhen you feel cold, what you're feeling is heat being transferred out of you.  If there is no breeze, you feel a certain way.  If there's a breeze, you will get colder faster (because the moving air is pulling the heat away from you), and if you get into water, its quite good at pulling heat from you.   Get out of the water and ha

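To write its answers, the model needs supporting information, which we will retrieve from Wikipedia. Let's load the Wiki40b snapshot, which comes pre-split into short passages (or "snippets"):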

In [None]:
wiki40b_snippets = nlp.load_dataset('wiki_snippets', name='wiki40b_en_100_0', experimental=True)['train']
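The configuration name `wiki40b_en_100_0`, together with the index name `wiki40b_snippets_100w` used further down, suggests that articles are cut into passages of about 100 words. The sketch below shows that kind of splitting on a toy article. The function name is hypothetical, and whether the `wiki_snippets` script uses non-overlapping windows exactly like this is an assumption.

```python
def split_into_passages(text, passage_len=100):
    """Hypothetical helper: split an article into consecutive, non-overlapping
    passages of `passage_len` whitespace-separated words (the last may be shorter)."""
    words = text.split()
    return [
        " ".join(words[i:i + passage_len])
        for i in range(0, len(words), passage_len)
    ]

# toy "article" of 250 words
article = " ".join("w{}".format(i) for i in range(250))
passages = split_into_passages(article)
print(len(passages), [len(p.split()) for p in passages])  # 3 [100, 100, 50]
```

Shorter passages make each retrieved unit more focused, at the cost of cutting some explanations across passage boundaries.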

### Note on Data and Biases
<a id='reddit_biases'></a>

Data gathered from Reddit comes with its share of problems and biases. We hope that the ELI5 and AskScience subsets are somewhat better in this respect, but there is still much to do.

## Retrieving Support Documents
<a id='retrieval'></a>

The first question is...

### Sparse Retrieval with ElasticSearch
<a id='elasticsearch'></a>

The traditional approach until...  

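First, let's create a sparse ElasticSearch index of the Wikipedia snippets (the code below expects an ElasticSearch server running on `localhost:9200`):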

In [1]:
from elasticsearch import Elasticsearch

# connect to the local ElasticSearch server
es_client = Elasticsearch([{'host': 'localhost', 'port': 9200}])
# build the Wikipedia snippet index only if it does not already exist
if not es_client.indices.exists(index='wiki40b_snippets_100w'):
    make_es_index_snippets(es_client, wiki40b_snippets, index_name='wiki40b_snippets_100w')

Now let's test it on one of the ELI5 questions:

In [6]:
question = eli5['test_eli5'][12345]['title']
doc, res_list = query_es_index(question, es_client, index_name='wiki40b_snippets_100w', n_results=10)

print(question)
print('-----\n')
for res in res_list:
    print("{}: \n  {}\n".format(
        res['article_title'],
        res['section_title'] if res['section_title'].strip() != '' else res['article_title']
    ))

Why does water heated to room temperature feel colder than the air around it?
-----

Salt fingering: 
  Salt fingering

Solar water heating: 
  Flat plate & Evacuated tube

Humidifier: 
  Fixed-installation humidifiers & Problems

Drake Landing Solar Community: 
  How it works & Energy centre

Diamond dust: 
  Characteristics & Formation

Effects of global warming on oceans: 
  Ocean currents

Mesoscale convective system: 
  Lake-effect snow

Thermal comfort: 
  Interplay of temperature and humidity

Honyaki: 
  Traditional process

Greywell Tunnel: 
  SSSI
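Assuming the index uses ElasticSearch's default similarity, passages are ranked with BM25, which rewards passages that share rare words with the question. This likely explains hits such as *Salt fingering* or *Solar water heating* above: they match words like "water" and "heated" without addressing the question. Below is a minimal pure-Python sketch of Okapi BM25 scoring. `k1=1.2` and `b=0.75` are ElasticSearch's documented defaults, but the function is hypothetical and ElasticSearch's analyzers (tokenization, stop words) mean its scores will not match these exactly.

```python
import math
from collections import Counter

def bm25_scores(query, docs, k1=1.2, b=0.75):
    """Score each document against `query` with Okapi BM25,
    using lowercase whitespace tokenization."""
    tokenized = [d.lower().split() for d in docs]
    n_docs = len(tokenized)
    avg_len = sum(len(t) for t in tokenized) / n_docs
    # document frequency: number of documents containing each term
    df = Counter(term for toks in tokenized for term in set(toks))
    scores = []
    for toks in tokenized:
        tf = Counter(toks)
        score = 0.0
        for term in set(query.lower().split()):
            if term not in tf:
                continue
            # rare terms get a higher inverse document frequency
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            # term-frequency saturation (k1) and document-length normalization (b)
            norm = tf[term] + k1 * (1 - b + b * len(toks) / avg_len)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores

docs = [
    "water transfers heat from your skin faster than air",
    "chocolate is made from the seeds of the cacao tree",
    "air is a poor conductor of heat",
]
scores = bm25_scores("why does water feel colder than air", docs)
best = max(range(len(docs)), key=lambda i: scores[i])
print(best, docs[best])
```

Because scoring only counts shared words, a passage can rank highly while being about something else entirely, which motivates the learned retriever in the next section.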



### Training a Dense Retriever with ELI5 and in-batch Negatives
<a id='dense_train'></a>

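Can we take advantage of the ELI5 data itself to do better? Rather than relying on word overlap, we can train an encoder that maps questions and their answers close together in a shared vector space, then use it to embed Wikipedia passages. We start from a small pre-trained BERT model (8 layers, hidden size 768):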

In [7]:
qar_tokenizer, qar_model = make_qa_retriever_model(
    model_name="google/bert_uncased_L-8_H-768_A-12",
    from_file=None,
    device="cuda:0"
)
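The retriever embeds questions and answers as vectors (the passage representations loaded in the next section have 128 dimensions). With in-batch negatives, each question's own answer is the positive example and the other answers in the same batch serve as negatives, so no extra negative mining is needed. The NumPy sketch below computes that cross-entropy loss. It is an illustration under assumptions, not the code in `eli5_utils.py`, and whether that helper also adds the symmetric answer-to-question term is not shown here.

```python
import numpy as np

def in_batch_negatives_loss(q_reps, a_reps):
    """Cross-entropy loss where question i must pick out answer i among all
    B answers in the batch; the other B - 1 answers act as negatives."""
    scores = q_reps @ a_reps.T                            # (B, B) dot products
    scores = scores - scores.max(axis=1, keepdims=True)   # numerical stability
    log_probs = scores - np.log(np.exp(scores).sum(axis=1, keepdims=True))
    # the correct answer for question i sits on the diagonal
    return -np.mean(np.diag(log_probs))

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 128))
matched = in_batch_negatives_loss(q, q)                  # each answer equals its question
shuffled = in_batch_negatives_loss(q, q[[1, 2, 3, 0]])   # answers paired with the wrong questions
print(matched < shuffled)  # True
```

Larger batches provide more negatives per question for free, which is one reason dense retrievers are often trained with large batch sizes.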

### Using a Trained Dense Retriever
<a id='dense_use'></a>

Once a retriever has been trained, we can load its weights from a saved checkpoint instead of starting from the pre-trained BERT model:

In [3]:
qar_tokenizer, qar_model = make_qa_retriever_model(
    model_name="google/bert_uncased_L-8_H-768_A-12",
    from_file="retriever_models/eli5_retriever_model_l-8_h-768_b-512-512_9.pth",
    device="cuda:0"
)

In [4]:
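import faiss
import numpy as np

# GPU resources used by the faiss index
faiss_res = faiss.StandardGpuResources()
# read-only memory map of the pre-computed 128-dimensional passage representations
wiki40b_passage_reps = np.memmap(
    'wiki40b_passages_reps_32_l-8_h-768_b-512-512.dat',
    dtype='float32', mode='r',
    shape=(wiki40b_snippets.num_rows, 128)
)

# exact inner-product index, copied to GPU device 1, holding one vector per snippet
wiki40b_index_flat = faiss.IndexFlatIP(128)
wiki40b_gpu_index = faiss.index_cpu_to_gpu(faiss_res, 1, wiki40b_index_flat)
wiki40b_gpu_index.add(wiki40b_passage_reps)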

In [10]:
question = eli5['test_eli5'][12345]['title']
doc, res_list = query_qa_dense_index(
    question,
    qar_model, qar_tokenizer,
    wiki40b_snippets, wiki40b_gpu_index,
    n_results=10
)

print(question)
print('-----\n')
for res in res_list:
    print("{}: \n  {}\n".format(
        res['article_title'],
        res['section_title'] if res['section_title'].strip() != '' else res['article_title']
    ))

Why does water heated to room temperature feel colder than the air around it?
-----

Fugacity: 
  History

Heat transfer: 
  Heat transfer in the human body & Evaporative cooling

Johan Sandström: 
  Sandström  Theorem

Thermal equilibrium: 
  Bodies prepared with separately uniform temperatures, then put into purely thermal communication with each other

Evaporative cooler: 
  Physical principles

Thermal contact conductance: 
  Factors influencing contact conductance & Contact pressure

Thermodynamic temperature: 
  The heat of phase changes

Temperature: 
  Local thermodynamic equilibrium & Bodies in thermodynamic equilibrium

Tail flick test: 
  Limitations

Latent heat: 
  Usage
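The flat index built with `faiss.IndexFlatIP` performs exact maximum inner product search: every stored passage vector is scored against the question vector and the `k` highest scores are returned. The NumPy sketch below reproduces that computation on toy 2-dimensional vectors. The function name is hypothetical, and it ignores the GPU placement and batching that faiss handles. faiss's own `index.search(queries, k)` likewise returns a (scores, indices) pair.

```python
import numpy as np

def max_inner_product_search(passage_reps, query_reps, k):
    """Exact top-k search by inner product, like a flat inner-product index.
    Returns (scores, indices), each of shape (n_queries, k), best first."""
    scores = query_reps @ passage_reps.T          # (n_queries, n_passages)
    top = np.argsort(-scores, axis=1)[:, :k]      # indices of the k largest scores
    return np.take_along_axis(scores, top, axis=1), top

passages = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]], dtype="float32")
query = np.array([[0.9, 0.1]], dtype="float32")
scores, ids = max_inner_product_search(passages, query, k=2)
print(ids)  # [[0 2]]
```

Exact search costs one dot product per stored passage, which a GPU handles well at this scale; approximate faiss indexes trade some accuracy for speed when the collection grows.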



### Retriever Evaluation
<a id='dense_eval'></a>

How can we evaluate the retrievers? One option is to measure how close the retrieved passages are to the reference ELI5 answers. Let's start by grabbing a couple of useful metrics from the `nlp` library:

In [124]:
%%capture --no-stdout
from time import time

import numpy as np

# load the ROUGE and BERTScore metrics from the nlp library
nlp_rouge = nlp.load_metric('rouge')
nlp_bertscore = nlp.load_metric('bertscore')

# takes a list of retrieved passages and a list of possible answers
# for a question and returns a measure of the lexical overlap between the
# passages and the answers
def get_aggregate_rouge(res_list, answers):
    res = np.zeros((len(res_list), len(answers), 3))
    for i, hit in enumerate(res_list):
        for j, a in enumerate(answers):
            if len(hit.strip()) > 0 and len(a.strip()) > 0:
                # get ROUGE-1 P/R/F for each passage/answer pair
                score = nlp_rouge.compute([hit], [a], rouge_types=['rouge1'])['rouge1'].mid
                res[i, j] = np.array([score.precision, score.recall, score.fmeasure])
    # average P/R/F scores, then find the best passage-answer match
    return res.mean(axis=2).max()

# same with the BERTScore metric, which aligns contextual word embeddings
def get_aggregate_bertscore(res_list, answers):
    res = np.zeros((len(res_list), len(answers), 3))
    for i, hit in enumerate(res_list):
        for j, a in enumerate(answers):
            if len(hit.strip()) > 0 and len(a.strip()) > 0:
                # get BERTScore P/R/F for each passage/answer pair
                score = nlp_bertscore.compute([hit], [a], lang='en')
                res[i, j] = np.array([score['precision'].item(), score['recall'].item(), score['f1'].item()])
    # average P/R/F scores, then find the best passage-answer match
    return res.mean(axis=2).max()

# Compare which retriever finds passages that are closest to the ELI5 answers,
# both in lexical overlap (ROUGE) and in embedding similarity (BERTScore)
st_time = time()
tot_rg_sparse = 0.
tot_bs_sparse = 0.
tot_rg_dense = 0.
tot_bs_dense = 0.
valid_slice = eli5['validation_eli5'][:1000]
for i, (question, answers) in enumerate(zip(valid_slice['title'], valid_slice['answers'])):
    # get documents with sparse retriever
    _, sparse_res_list = query_es_index(
        question,
        es_client, index_name='wiki40b_snippets_100w',
        n_results=5
    )
    sparse_passages = [res['passage_text'] for res in sparse_res_list]
    if len(sparse_passages) == 0:
        sparse_passages = [question]
    tot_rg_sparse += get_aggregate_rouge(sparse_passages, answers['text'])
    tot_bs_sparse += get_aggregate_bertscore(sparse_passages, answers['text'])
    # get documents with dense retriever
    _, dense_res_list = query_qa_dense_index(
        question,
        qar_model, qar_tokenizer,
        wiki40b_snippets, wiki40b_gpu_index,
        n_results=5
    )
    dense_passages = [res['passage_text'] for res in dense_res_list]
    tot_rg_dense += get_aggregate_rouge(dense_passages, answers['text'])
    tot_bs_dense += get_aggregate_bertscore(dense_passages, answers['text'])
    # show average scores side by side
    if (i+1) % 10 == 0:
        print("{:03d} Sparse: RG-{:.4f} BS-{:.4f} | Dense: RG-{:.4f} BS-{:.4f} \t {:.2f}".format(
            i+1,
            tot_rg_sparse / (i+1), tot_bs_sparse / (i+1),
            tot_rg_dense / (i+1), tot_bs_dense / (i+1),
            time() - st_time
        ))

010 Sparse: RG-0.2652 BS-0.8053 | Dense: RG-0.2521 BS-0.8141 	 103.34
020 Sparse: RG-0.2657 BS-0.8068 | Dense: RG-0.2647 BS-0.8193 	 174.52
030 Sparse: RG-0.2631 BS-0.8052 | Dense: RG-0.2591 BS-0.8156 	 261.19
040 Sparse: RG-0.2623 BS-0.8049 | Dense: RG-0.2594 BS-0.8149 	 352.42
050 Sparse: RG-0.2660 BS-0.8060 | Dense: RG-0.2639 BS-0.8176 	 457.43
060 Sparse: RG-0.2698 BS-0.8053 | Dense: RG-0.2649 BS-0.8172 	 540.86
070 Sparse: RG-0.2684 BS-0.8058 | Dense: RG-0.2630 BS-0.8187 	 602.46
080 Sparse: RG-0.2671 BS-0.8062 | Dense: RG-0.2640 BS-0.8185 	 694.04
090 Sparse: RG-0.2646 BS-0.8063 | Dense: RG-0.2622 BS-0.8182 	 763.54
100 Sparse: RG-0.2627 BS-0.8058 | Dense: RG-0.2619 BS-0.8190 	 822.09
110 Sparse: RG-0.2646 BS-0.8056 | Dense: RG-0.2626 BS-0.8186 	 900.97
120 Sparse: RG-0.2673 BS-0.8055 | Dense: RG-0.2661 BS-0.8177 	 1013.28
130 Sparse: RG-0.2685 BS-0.8053 | Dense: RG-0.2678 BS-0.8175 	 1080.84
140 Sparse: RG-0.2660 BS-0.8050 | Dense: RG-0.2654 BS-0.8180 	 1380.19
150 Sparse: RG-0.

KeyboardInterrupt: 

In [123]:
sparse_res_list

[]

In [120]:
print("{:03d} Sparse: RG-{:.4f} BS-{:.4f} | Dense: RG-{:.4f} BS-{:.4f} \t {:.2f}".format(
            i+1,
            tot_rg_sparse / (i+1), tot_bs_sparse / (i+1),
            tot_rg_dense / (i+1), tot_bs_dense / (i+1),
            time() - st_time
        ))

199 Sparse: RG-0.2609 BS-0.8014 | Dense: RG-0.2634 BS-0.8135 	 1923.04
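ROUGE-1 is based on unigram overlap: precision is the fraction of passage words found in the answer, recall is the fraction of answer words found in the passage, and F1 is their harmonic mean. `get_aggregate_rouge` then averages the three values for each passage/answer pair and keeps the best pair. The sketch below reimplements that logic in plain Python with lowercase whitespace tokenization. The function names are hypothetical, and the `nlp` metric tokenizes differently, so its numbers will not match exactly.

```python
from collections import Counter

def rouge1_prf(candidate, reference):
    """Unigram-overlap precision, recall and F1 between two strings
    (lowercased whitespace tokens, no stemming)."""
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((cand & ref).values())   # clipped count of shared unigrams
    if overlap == 0:
        return 0.0, 0.0, 0.0
    p = overlap / sum(cand.values())
    r = overlap / sum(ref.values())
    return p, r, 2 * p * r / (p + r)

def aggregate_rouge1(passages, answers):
    """Average P/R/F for each passage-answer pair, then keep the best pair."""
    return max((sum(rouge1_prf(p, a)) / 3 for p in passages for a in answers), default=0.0)

print(rouge1_prf("water absorbs heat", "water absorbs heat faster than air"))
# (1.0, 0.5, 0.6666666666666666)
```

Taking the maximum over pairs rewards a retriever if any one of its passages matches any one of the reference answers, which suits ELI5 questions that often have several valid answers.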


## Answer Generation Model
<a id='generation'></a>

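Once we have a question and a support document containing the passages retrieved from Wikipedia, we can train a sequence-to-sequence model to generate an answer conditioned on both. Here, we fine-tune [BART](https://arxiv.org/abs/1910.13461) (`facebook/bart-large`) on the ELI5 training set, using support documents pre-computed with the dense retriever.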



In [2]:
class ArgumentsS2S():
    def __init__(self):
        self.batch_size = 16
        self.backward_freq = 8
        self.max_length = 1024
        self.print_freq = 100
        self.model_save_name = "seq2seq_models/bart_model"
        self.learning_rate = 2e-4
        self.num_epochs = 20

s2s_args = ArgumentsS2S()

In [3]:
qa_s2s_tokenizer, pre_model = make_qa_s2s_model(
    model_name="facebook/bart-large",
    from_file=None,
    device="cuda:0"
)
qa_s2s_model = torch.nn.DataParallel(pre_model)

In [4]:
import json

# load the support documents pre-computed with the dense retriever
with open('precomputed/eli5_train_precomputed_dense_docs.json') as f:
    eli5_train_docs = json.load(f)
with open('precomputed/eli5_valid_precomputed_dense_docs.json') as f:
    eli5_valid_docs = json.load(f)

# build the document caches from the precomputed (key, document, sources) triples
s2s_train_dset = ELI5DatasetS2S(eli5['train_eli5'], document_cache={k: d for k, d, src_ls in eli5_train_docs})
s2s_valid_dset = ELI5DatasetS2S(eli5['validation_eli5'], document_cache={k: d for k, d, src_ls in eli5_valid_docs}, training=False)

In [5]:
s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=400,
        num_training_steps=s2s_args.num_epochs * math.ceil(len(s2s_train_dset) / s2s_args.batch_size)
)
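`get_linear_schedule_with_warmup` raises the learning rate linearly from 0 to the optimizer's initial value over `num_warmup_steps`, then decays it linearly to 0 at `num_training_steps`. The pure-Python sketch below computes that multiplier along with the `num_training_steps` arithmetic from the cell above. The dataset size `num_examples=1000` is made up for illustration, and the function name is hypothetical.

```python
import math

def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: ramps linearly from 0 to 1 during warmup,
    then decays linearly back to 0 at num_training_steps."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

# hypothetical sizes, only to illustrate the arithmetic
num_examples, batch_size, num_epochs = 1000, 16, 20
total_steps = num_epochs * math.ceil(num_examples / batch_size)   # 20 * 63 = 1260
lr = 2e-4
for step in (0, 200, 400, 830, 1260):
    print(step, lr * linear_warmup_factor(step, 400, total_steps))
```

The schedule advances once per `scheduler.step()` call. If `train_qa_s2s_epoch` only steps the scheduler once every `backward_freq` batches when accumulating gradients (an assumption; check `eli5_utils.py`), the step count computed in the cell above would exceed the number of updates actually made, and the learning rate would not reach 0 by the end of training.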

In [None]:
for e in range(s2s_args.num_epochs):
    train_qa_s2s_epoch(
        qa_s2s_model,
        s2s_train_dset, qa_s2s_tokenizer,
        s2s_optimizer, s2s_scheduler,
        s2s_args, e
    )
    m_save_dict = {
        'model': qa_s2s_model.state_dict(),
        'optimizer': s2s_optimizer.state_dict(),
        'scheduler': s2s_scheduler.state_dict(),
    }
    print("Saving model {}".format(s2s_args.model_save_name))
    torch.save(m_save_dict, '{}_{}.pth'.format(s2s_args.model_save_name, e))



 0     0 of 36184 	 L: 4.835 	 -- 29.596
 0     1 of 36184 	 L: 4.687 	 -- 31.556
 0   100 of 36184 	 L: 4.478 	 -- 150.420
 0   200 of 36184 	 L: 3.755 	 -- 269.327
 0   300 of 36184 	 L: 3.469 	 -- 387.855
 0   400 of 36184 	 L: 3.360 	 -- 505.592
 0   500 of 36184 	 L: 3.311 	 -- 624.245
 0   600 of 36184 	 L: 3.274 	 -- 741.951
 0   700 of 36184 	 L: 3.242 	 -- 860.024
 0   800 of 36184 	 L: 3.231 	 -- 977.969
 0   900 of 36184 	 L: 3.214 	 -- 1096.494
 0  1000 of 36184 	 L: 3.230 	 -- 1214.687
 0  1100 of 36184 	 L: 3.227 	 -- 1332.183
 0  1200 of 36184 	 L: 3.186 	 -- 1450.259
 0  1300 of 36184 	 L: 3.206 	 -- 1568.458
 0  1400 of 36184 	 L: 3.209 	 -- 1686.349
 0  1500 of 36184 	 L: 3.184 	 -- 1804.035
 0  1600 of 36184 	 L: 3.204 	 -- 1922.431
 0  1700 of 36184 	 L: 3.190 	 -- 2040.502
 0  1800 of 36184 	 L: 3.174 	 -- 2158.630
 0  1900 of 36184 	 L: 3.181 	 -- 2276.222
 0  2000 of 36184 	 L: 3.187 	 -- 2393.720
 0  2100 of 36184 	 L: 3.192 	 -- 2511.511
 0  2200 of 36184 	 L: 

In [15]:
torch.cuda.empty_cache()

In [11]:
_ = qa_s2s_model.eval()
s2s_args.print_freq = 100
eval_qa_s2s_epoch(
        qa_s2s_model,
        s2s_valid_dset, qa_s2s_tokenizer,
        s2s_args
)

    0 of  2453 	 L: 3.521 	 -- 0.315
 1000 of  2453 	 L: 3.260 	 -- 319.746
 2000 of  2453 	 L: 3.264 	 -- 638.111
Total 	 L: 3.265 	 -- 782.534


In [29]:
eli5['validation_eli5'][11]

{'q_id': '20q8w1',
 'title': 'How do apps like soundhound and shazam know what song is playing?',
 'selftext': '',
 'document': '',
 'subreddit': 'explainlikeimfive',
 'answers': {'a_id': ['cg5r130'],
  'text': ['ELI5:\n\nThink about when you hear your parents, you can recognize their voice right? Or when you see a dog, you can recognize it\'s a dog in general. Now what kind of dog? You can typically recognize it\'s a chihuahua or, my fav, a golden retriever. How? Chihuahuas are small and annoying with short hair, whereas a golden retriever is cute, cuddly, friendly, with long hair (I may have some bias here).\n\nIn the same way, Shazam and Soundhound does that! They take a look at features of a song, like the pitch, tone, or waveform (the "shape" of the song) and try to match it to a song in their memory.'],
  'score': [2]},
 'title_urls': {'url': []},
 'selftext_urls': {'url': []},
 'answers_urls': {'url': []}}

In [25]:
print(qa_s2s_generate(
        s2s_valid_dset[11][0], qa_s2s_model.module, qa_s2s_tokenizer,
        num_answers=1,
        num_beams=8,
        min_len=64,
        max_len=256,
        max_input_length=1024,
        device="cuda:0"
    )[0])

Soundhound and Shazam don't "know" what song is playing, but they do have a database of songs that they can look at to find out what song they're listening to. 

When you play a song, the app listens to the song and compares it to the database that it knows what song it's listening to, and if it finds a match, it knows it's playing the song.

If it can't find the song, it doesn't know what's playing.
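The generation call above uses beam search with `num_beams=8`, `min_len=64` and `max_len=256`. Presumably `qa_s2s_generate` forwards these to the 🤗transformers `generate` method (an assumption). To illustrate what the beam and length parameters control, here is a toy beam search over a three-token vocabulary. The toy model and the function are hypothetical, and real implementations add refinements such as length penalties and keeping `num_beams` active hypotheses at every step.

```python
import math

def beam_search(next_log_probs, eos_id, num_beams=3, min_len=2, max_len=5):
    """Toy beam search. `next_log_probs(prefix)` returns log-probabilities over
    the vocabulary for the next token. EOS is forbidden before `min_len` tokens."""
    beams = [([], 0.0)]   # (tokens, cumulative log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            for tok, lp in enumerate(next_log_probs(tokens)):
                if tok == eos_id and len(tokens) < min_len:
                    continue  # minimum-length constraint
                candidates.append((tokens + [tok], score + lp))
        # keep only the num_beams highest-scoring extensions
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:num_beams]:
            if tokens[-1] == eos_id:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished += beams     # hypotheses that reached max_len without EOS
    return max(finished, key=lambda c: c[1])[0]

def toy_model(prefix):
    # token 0 is EOS and is always the single most likely next token
    return [math.log(0.6), math.log(0.3), math.log(0.1)]

print(beam_search(toy_model, eos_id=0, min_len=0))  # [0]
print(beam_search(toy_model, eos_id=0, min_len=2))  # [1, 1, 0]
```

Without a minimum length, the toy model stops immediately. This is the same failure mode that a large `min_len` guards against when we want multi-sentence answers.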


In [None]:
generated = []
st_time = time()
for i in range(2000):
    generated += [qa_s2s_generate(
        s2s_valid_dset[i][0], qa_s2s_model.module, qa_s2s_tokenizer,
        num_answers=1,
        num_beams=8,
        min_len=64,
        max_len=256,
        max_input_length=1024,
        device="cuda:0"
    )[0]]
    if i % 100 == 0:
        print(eli5['validation_eli5'].num_rows, i, time() - st_time)

In [37]:
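def qda_difficulty(question_doc, answer):
    # fraction of the answer's words that already appear in the question + support document
    qd_words = set(question_doc.lower().split())
    answer_words = answer.lower().split()
    if len(answer_words) == 0:  # avoid dividing by zero on empty answers
        return 0.0
    return sum(w in qd_words for w in answer_words) / len(answer_words)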

In [32]:
s2s_train_dset[0]

('question: in football whats the point of wasting the first two plays with a rush - up the middle - not regular rush plays i get those context: <p> blitz count), so called because the blitzing player must insert the word "mississippi" between numbers so as not to allow the player to count ridiculously fast and effectively give the quarterback no time to throw. sometimes the two rules are combined, allowing one separate call of "blitz!" per set of 4 downs. the other option to handle a rush is to use an offensive lineman or center to block any pass rush. a line is rare in street, and the act of a center snapping to a quarterback is completely optional. most teams that use a line opt for 3 <p> of the increased appearance fees that such a bowl generates for the conference. jim vertuno of the associated press wrote "leach was upset officials disallowed two tech touchdowns in the third quarter. the first was overruled when video replay clearly showed the receiver let the ball hit the ground
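The printed example shows how each training input is laid out. The question follows a `question:` prefix, and the support document follows `context:`, with each retrieved passage introduced by a `<p>` marker. The sketch below rebuilds an input in that format. `make_s2s_input` is a hypothetical helper inferred from the printed string, not a function from `eli5_utils.py`. Lowercasing everything is also an assumption, based on the example appearing lowercased.

```python
def make_s2s_input(question, passages):
    """Hypothetical helper: build 'question: ... context: <p> ... <p> ...'."""
    context = " ".join("<p> " + p for p in passages)
    return "question: {} context: {}".format(question, context).lower()

example = make_s2s_input(
    "Why is the sky blue?",
    ["Rayleigh scattering affects short wavelengths more.", "Sunlight contains all visible colors."],
)
print(example)
# question: why is the sky blue? context: <p> rayleigh scattering affects short wavelengths more. <p> sunlight contains all visible colors.
```

Keeping the question first means it is never cut off when a long support document has to be truncated to `max_length` tokens.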

In [47]:
recall_diff = [(i, qda_difficulty(*s2s_train_dset[i])) for i in range(10000)]

In [48]:
sorted(recall_diff, key=lambda x:x[1], reverse=True)[:10]

[(4885, 1.0),
 (4112, 0.9523809523809523),
 (8829, 0.9090909090909091),
 (9692, 0.9090909090909091),
 (3443, 0.9032258064516129),
 (2563, 0.9),
 (4930, 0.9),
 (6940, 0.9),
 (9267, 0.8928571428571429),
 (7644, 0.8888888888888888)]