## Named-Entity QG Model

In [1]:
import warnings
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
warnings.filterwarnings("ignore", message=r"Passing", category=UserWarning)
import os
import torch
import pickle
import itertools
import pandas as pd
import pyarrow.feather as feather
import spacy
nlp = spacy.load("en_core_web_sm")
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [2]:
class QGPipeline:
    """Generate questions from a paragraph, using named entities as answers.

    The paragraph is parsed with the module-level spaCy pipeline ``nlp``.
    Every named entity found in a sentence of at least
    ``min_anssent_token_num`` tokens (and whose label is not excluded)
    becomes a candidate answer. For each candidate, the entity is wrapped in
    ``<hl>`` markers inside the full paragraph, prefixed with
    ``"generate question: "``, suffixed with ``" </s>"``, and passed to a
    seq2seq question-generation model.
    """

    def __init__(
        self,
        qg_model_name: str,
        min_para_token_num: int,
        min_anssent_token_num: int,
        ans_ent_type_exclude_list: list,
        max_input_length: int = 512,
        max_question_length: int = 32,
        num_beams: int = 4,
    ):
        """Load the QG model/tokenizer and store filtering/generation settings.

        Args:
            qg_model_name: Hugging Face model id or local path, loaded with
                ``AutoTokenizer`` / ``AutoModelForSeq2SeqLM``.
            min_para_token_num: Paragraphs with fewer spaCy tokens are skipped.
            min_anssent_token_num: Entities in sentences with fewer spaCy
                tokens are not used as answers.
            ans_ent_type_exclude_list: spaCy entity labels (``ent.label_``)
                that must not be used as answers.
            max_input_length: Truncation length for tokenized model inputs.
            max_question_length: ``max_length`` passed to ``generate``.
            num_beams: Beam width passed to ``generate``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(qg_model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(qg_model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.min_para_token_num = min_para_token_num
        self.min_anssent_token_num = min_anssent_token_num
        self.ans_ent_type_exclude_list = ans_ent_type_exclude_list
        self.max_input_length = max_input_length
        self.max_question_length = max_question_length
        self.num_beams = num_beams

    def __call__(self, inputs: str):
        """Generate de-duplicated questions for one paragraph.

        Args:
            inputs: Raw paragraph text. Whitespace runs are collapsed to a
                single space before parsing, so every returned character
                offset refers to that normalized text.

        Returns:
            A list of dicts with keys ``question``, ``answer``, ``ans_pos``
            (``[start, end)`` char offsets of the answer) and ``ans-sent_pos``
            (``[start, end)`` char offsets of the answer's sentence). Empty
            when the paragraph is shorter than ``min_para_token_num`` tokens
            or yields no candidate answers. When several answers produce the
            same question string, only the first one is kept.
        """
        nlp_inputs = nlp(" ".join(inputs.split()))
        if len(nlp_inputs) < self.min_para_token_num:
            return []
        context = nlp_inputs.text
        sents, sents_pos, answers_pos = self._extract_answers_by_NER(nlp_inputs, context)
        # answers_pos holds one (possibly empty) list per sentence.
        if not any(answers_pos):
            return []
        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(context, sents, sents_pos, answers_pos)
        questions = self._generate_questions([example["source_text"] for example in qg_examples])

        output = []
        seen_questions = set()  # set membership keeps de-duplication linear
        for example, question in zip(qg_examples, questions):
            if question in seen_questions:
                continue
            seen_questions.add(question)
            output.append({
                "question": question,
                "answer": example["answer"],
                "ans_pos": example["answer_pos"],
                "ans-sent_pos": example["answer-sent_pos"],
            })
        return output

    def _extract_answers_by_NER(self, nlp_context, context):
        """Collect sentences and candidate answer spans from a parsed paragraph.

        Args:
            nlp_context: spaCy ``Doc`` of the paragraph.
            context: The text the ``Doc`` was built from.

        Returns:
            ``(sents, sents_pos, answers_pos)``: the sentence texts, their
            ``[start_char, end_char]`` offsets, and for every sentence a list
            of ``[start_char, end_char]`` entity offsets. That list is empty
            for sentences shorter than ``min_anssent_token_num`` tokens.
        """
        sents, sents_pos, answers_pos = [], [], []
        for sent in nlp_context.sents:
            sents.append(sent.text)
            sents_pos.append([sent.start_char, sent.end_char])
            sent_ans_pos = []
            if len(sent) >= self.min_anssent_token_num:
                for ent in sent.ents:
                    # Keep only entities whose offsets map exactly back onto the
                    # context and whose label is not excluded.
                    if (
                        context[ent.start_char:ent.end_char] == ent.text
                        and ent.label_ not in self.ans_ent_type_exclude_list
                    ):
                        sent_ans_pos.append([ent.start_char, ent.end_char])
            answers_pos.append(sent_ans_pos)
        return sents, sents_pos, answers_pos

    def _prepare_inputs_for_qg_from_answers_hl(self, context, sents, sents_pos, answers_pos):
        """Build one highlighted model prompt per candidate answer.

        The answer is wrapped in ``<hl>`` markers inside its sentence. All
        sentences are joined with single spaces, prefixed with
        ``"generate question: "`` and suffixed with ``" </s>"``.

        Returns:
            A list of dicts with keys ``answer``, ``answer_pos``,
            ``answer-sent_pos`` and ``source_text``. ``sents`` is not mutated.

        Raises:
            ValueError: If an answer span does not line up with its sentence's
                offsets, i.e. the inputs are inconsistent.
        """
        inputs = []
        for i, answer_pos in enumerate(answers_pos):
            sent = sents[i]
            sent_start = sents_pos[i][0]
            for pos_tuple in answer_pos:
                # Convert paragraph-level offsets to sentence-level offsets.
                ans_start_idx = pos_tuple[0] - sent_start
                ans_end_idx = pos_tuple[1] - sent_start
                answer_text = sent[ans_start_idx:ans_end_idx]
                answer_check = context[pos_tuple[0]:pos_tuple[1]]
                # Explicit check instead of `assert`, which `python -O` strips.
                if answer_text != answer_check:
                    raise ValueError(
                        f"answer span {pos_tuple} does not align with sentence {i}: "
                        f"{answer_text!r} != {answer_check!r}"
                    )
                highlighted = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_end_idx:]}"
                source_sents = sents[:i] + [highlighted] + sents[i + 1:]
                source_text = f"generate question: {' '.join(source_sents)} </s>"
                inputs.append({
                    "answer": answer_text,
                    "answer_pos": pos_tuple,
                    "answer-sent_pos": sents_pos[i],
                    "source_text": source_text,
                })
        return inputs

    def _tokenize(self, inputs):
        """Tokenize a batch of prompts into padded, truncated PyTorch tensors."""
        return self.tokenizer(
            inputs,
            max_length=self.max_input_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

    def _generate_questions(self, inputs):
        """Run beam-search generation on a batch of prompts.

        Returns one decoded question string per prompt, and an empty list for
        an empty batch. The tokenizer cannot encode an empty batch, so that
        case returns early.
        """
        if not inputs:
            return []
        encoded = self._tokenize(inputs)
        outs = self.model.generate(
            input_ids=encoded["input_ids"].to(self.device),
            attention_mask=encoded["attention_mask"].to(self.device),
            max_length=self.max_question_length,
            num_beams=self.num_beams,
        )
        return [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]

## Question Generation

In [3]:
# NOTE: pickle can execute arbitrary code while loading; only load trusted files.
with open("data/examples.pickle", "rb") as examples_file:
    examples = pickle.load(examples_file)

qg_model_name = "valhalla/t5-base-qg-hl"
min_para_token_num = 30
min_anssent_token_num = 10
qg_model = QGPipeline(qg_model_name, min_para_token_num, min_anssent_token_num, ans_ent_type_exclude_list=[])
# The QG model in our work is a modified version of "valhalla/t5-base-qg-hl", and the QGPipeline class here is based on "valhalla/t5-base-qg-hl".
# Other existing QG models are also available on Hugging Face, e.g. "mrm8488/t5-base-finetuned-question-generation-ap".
# Paragraphs with fewer than min_para_token_num tokens are skipped.
# Answers whose sentences have fewer than min_anssent_token_num tokens are skipped.
# Answers whose NER type is in ans_ent_type_exclude_list (e.g. ["PERSON", "ORG", ...]) are skipped; no types are excluded in our work.

In [4]:
# Run question generation on each (para_id, para_pub, para_text) example and
# tag every generated record with the id of its source paragraph.
total_results = []
for para_id, _pub, para_text in examples:
    para_results = qg_model(para_text)
    for record in para_results:
        record["para_id"] = para_id
    total_results += para_results

print(f"Generate {len(total_results)} questions.")

Generate 90 questions.


In [5]:
# One row per generated question; columns come from the result dict keys.
total_results_df = pd.DataFrame(data=total_results)
total_results_df.head(n=5)

Unnamed: 0,question,answer,ans_pos,ans-sent_pos,para_id
0,Which country's hard-line Revolutionary Guard ...,Iran,"[0, 4]","[0, 249]",1650014_0
1,What group said the death sentence against Sal...,Revolutionary Guard,"[17, 36]","[0, 249]",1650014_0
2,On what day did Iran's Revolutionary Guard say...,Saturday,"[45, 53]","[0, 249]",1650014_0
3,Who was sentenced to death for blasphemy again...,Salman Rushdie,"[86, 100]","[0, 249]",1650014_0
4,Who condemned Salman Rushdie?,Ayatollah Ruhollah Khomeini,"[125, 152]","[0, 249]",1650014_0


In [6]:
# Store the raw (unfiltered) QG results as a Feather file.
raw_results_path = "data/raw_results_After_2ndModule.feather"
total_results_df.to_feather(raw_results_path)