In [1]:
import numpy as np
class Freq_Regularization():
    '''
    Frequency-based regularizer: samples which generator (UQA or LM) should
    produce the next question token, based on the classes used so far.
    '''
    def predict_class(self, list_token_classes):
        '''
        UQA class = 0
        LM class = 1

        Input:
            - list_token_classes: classes (0 or 1) of the tokens generated so far.
              Any other value (e.g. the -1 placeholder used for pre-seeded tokens
              whose class is unknown) is ignored.
        Output:
            - the sampled class (0 or 1) for the next token. Class 0 (UQA) is drawn
              with probability equal to the fraction of LM tokens in the history and
              class 1 (LM) with the complementary probability, so the generator that
              has been used less so far is favoured.
        '''
        # Only real class labels may enter the statistics: a -1 placeholder would make
        # the "probability" negative and np.random.choice would raise ValueError.
        known_classes = [c for c in list_token_classes if c in (0, 1)]

        # No usable history: both generators are equally likely.
        if len(known_classes) == 0:
            return np.random.choice(2, 1, p=[0.5, 0.5])[0]

        # Labels are 0/1, so the mean is the fraction of LM tokens so far.
        lm_fraction = sum(known_classes) / len(known_classes)
        return np.random.choice(2, 1, p=[lm_fraction, 1 - lm_fraction])[0]


In [2]:
from BERTQG.token_generation import load_model, generate_token
from Regularization_module import Regularizer_Discriminator
from itertools import groupby

class QuestionGeneration():
    '''
    Generates a question token by token. At every step a regularizer decides whether
    the next token is produced by the LM-based or by the UQA-based generation model.
    '''
    def __init__(self, bert_model, lm_gq, uqa_qg, regularization=None):
        '''
        Input:
            - bert_model: passed to load_model for both generators
            - lm_gq: checkpoint of the LM-based generator (the parameter name keeps
              its historical spelling so existing keyword callers keep working)
            - uqa_qg: checkpoint of the UQA-based generator
            - regularization: argument forwarded to Regularizer_Discriminator
        '''
        # Use the argument itself: the previous code read the name `lm_qg`, which is
        # not a parameter and only resolved if a notebook global of that name existed.
        self.lm_qg, _, _ = load_model(bert_model, lm_gq)
        # Tokenizer and device are taken from the second load_model call only.
        self.uqa_qg, self.tokenizer, self.device = load_model(bert_model, uqa_qg)
        self.regularization = Regularizer_Discriminator(regularization)

    def generate_question(self, context: str, ans: str, ans_start: int, wh_word: str, list_question_tokens: list = None) -> tuple:
        '''
        Input:
            - context
            - ans
            - ans_start
            - wh_word: forwarded unchanged to generate_token
            - list_question_tokens: optional history of already generated question
              tokens (strings). This is given for testing. The list is copied and
              never modified. Seed tokens get the placeholder -1 as class, index and
              probabilities because they were not produced here.
        Output:
            - (question, list_token_classes, list_qi_idx, list_qi_probs): the generated
              tokens joined by spaces, and for every kept token its class
              (0 = UQA, 1 = LM) plus the index and probabilities returned by
              generate_token. All three lists have one entry per question token.
        '''
        # Private copy: neither the caller's list nor a shared default is mutated.
        list_question_tokens = list(list_question_tokens) if list_question_tokens else []
        num_seed_tokens = len(list_question_tokens)
        list_token_classes = [-1] * num_seed_tokens  # 0 = UQA, 1 = LM
        list_qi_idx = [-1] * num_seed_tokens
        list_qi_probs = [-1] * num_seed_tokens

        qi = None
        max_length = 50
        # generation finished when [SEP] is created
        while not self.__finished_generation(qi):
            # Merge word pieces ("com ##pen" -> "compen") to obtain the question text so far.
            question_text = " ".join(list_question_tokens).replace(' ##', '')

            # Get token class to use
            qi_class = self.regularization.predict_class(context, ans, ans_start, question_text, list_question_tokens)

            # Generate the token, index and probs of the ith question token with the
            # selected model. The index history is passed so that generate_token can
            # penalize repeated tokens.
            qg_model = self.lm_qg if qi_class == 1 else self.uqa_qg
            qi, qi_idx, qi_probs = generate_token(qg_model, self.tokenizer, self.device, wh_word, list_qi_idx, context, question_text, ans, ans_start)

            list_question_tokens.append(qi)
            list_token_classes.append(qi_class)
            list_qi_idx.append(qi_idx)
            list_qi_probs.append(qi_probs)

            # Safety cap in case [SEP] is never produced.
            if len(list_question_tokens) > max_length:
                break

        # Indices to keep: first token of every run of repeats, without the final
        # entry, which is [SEP] when generation ended normally.
        # NOTE(review): when the loop stopped on the length cap instead of on [SEP],
        # this also drops the last generated token - confirm that is intended.
        list_idx = self.__remove_consecutive_repeated_tokens(list_question_tokens)[:-1]

        list_question_tokens = [list_question_tokens[idx] for idx in list_idx]
        list_token_classes = [list_token_classes[idx] for idx in list_idx]
        list_qi_idx = [list_qi_idx[idx] for idx in list_idx]
        list_qi_probs = [list_qi_probs[idx] for idx in list_idx]

        assert len(list_question_tokens) == len(list_token_classes) == len(list_qi_idx) == len(list_qi_probs)

        return " ".join(list_question_tokens), list_token_classes, list_qi_idx, list_qi_probs

    def __finished_generation(self, question_token):
        '''True once the given token is the end-of-question marker [SEP].'''
        return question_token == '[SEP]'

    def __remove_consecutive_repeated_tokens(self, list_tokens):
        '''
        Removes consecutive tokens.
        Sometimes the generated question is "when when did...",
        so we need to remove one when
        Output:
            - list of indices to keep: the position of the first token of every run
              of identical consecutive tokens
        Example:
            [1,1,1,1,1,1,2,3,4,4,5,1,2] -> [0, 6, 7, 8, 10, 11, 12]
        '''
        # Group (index, token) pairs by token; the first pair of each group carries
        # the index of the first token of that run.
        return [next(run)[0] for _, run in groupby(enumerate(list_tokens), key=lambda pair: pair[1])]


05/13/2020 03:13:44 - INFO - transformers.file_utils -   PyTorch version 1.4.0 available.


Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


05/13/2020 03:13:46 - INFO - transformers.file_utils -   TensorFlow version 2.1.0 available.


In [3]:
# Checkpoint locations (relative paths) for the base BERT model, the two
# question-generation models and the regularizer used to pick between them.
bert_model = 'BERTQG/models/bert-base-uncased'
lm_qg = 'BERTQG/models/lm_10k_QG/pytorch_model.bin'
uqa_qg = 'BERTQG/models/uqa_10k_QG/pytorch_model.bin'
regul_model = 'models/discri_model_partial_perturb.bin' # path of the regularizer checkpoint
# Earlier variant without a regularizer checkpoint, kept for reference:
#QG = QuestionGeneration(bert_model, lm_qg, uqa_qg)
# Loads the models; QG is the generator instance used by the cells below.
QG = QuestionGeneration(bert_model, lm_qg, uqa_qg, regul_model)

05/13/2020 03:13:48 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/13/2020 03:13:49 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/13/2020 03:13:49 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/lm_10k_QG/pytorch_model.bin


05/13/2020 03:13:51 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/13/2020 03:13:51 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/13/2020 03:13:51 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/uqa_10k_QG/pytorch_model.bin


05/13/2020 03:13:53 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517
05/13/2020 03:13:53 - INFO - transformers.configuration_utils -   Model config BertConfig {
  "_num_labels": 2,
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
   

# Testing

In [127]:
# Single hand-written example used to try out question generation.
context = "The city has a proud history of theatre. Stephen Kemble of the famous Kemble family successfully managed the original Theatre Royal, Newcastle for fifteen years (1791–1806). He brought members of his famous acting family such as Sarah Siddons and John Kemble out of London to Newcastle. Stephen Kemble guided the theatre through many celebrated seasons. The original Theatre Royal in Newcastle was opened on 21 January 1788 and was located on Mosley Street. It was demolished to make way for Grey Street, where its replacement was built."
# Reference question for this example (not used by the code):
# question_text = "when did the original theatre royal open ?"
ans = "1788"
# Character offset of the first occurrence of the answer inside the context.
ans_start = context.find(ans)
# Sanity check: the offset really points at the answer text.
assert context[ans_start:ans_start+len(ans)] == ans

In [128]:
# generate_question requires the wh-word as its fourth positional argument; calling
# it without one raises TypeError with the class defined above.
# NOTE(review): 'When' is chosen to match the date answer and the capitalised
# wh-words stored in the dataset ('What', ...) - confirm the expected casing.
question, history, indices, probs = QG.generate_question(context, ans, ans_start, 'When')

In [129]:
question

'when did the original theatre royal open on 21 january ?'

In [130]:
history

[0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1]

In [131]:
indices

[2043, 2106, 1996, 2434, 3004, 2548, 2330, 2006, 2538, 2254, 1029]

In [132]:
probs[-1][102]

0.00454629585146904

In [133]:
def get_top_k(probs, k):
    '''
    Keep the k most probable entries of every per-token score sequence.

    Input:
        - probs: one sequence of scores per question token
        - k: number of entries to keep per token
    Output:
        - top_k_indices: for every token, the positions of its k highest scores,
          highest first (ties keep their original order)
        - top_k_probs: the scores found at those positions, in the same order
    '''
    top_k_indices, top_k_probs = [], []
    for token_probs in probs:
        scores = np.asarray(token_probs)
        # One stable sort yields the ranking; the scores are gathered from that same
        # ranking instead of sorting the whole sequence a second time.
        order = np.argsort(-scores, kind='stable')[:k]
        top_k_indices.append(order.tolist())
        top_k_probs.append(scores[order].tolist())
    return top_k_indices, top_k_probs

In [134]:
# Rank the distributions of the example question and keep the 10 best entries per token.
top_k_indices, top_k_probs = get_top_k(probs, 10)
# Show the candidates for the last question token: first the indices, then their probabilities.
print(top_k_indices[-1])
print(top_k_probs[-1])

[1029, 1012, 13739, 1010, 2960, 15622, 13393, 1024, 2997, 1011]
[0.6433014273643494, 0.10247533023357391, 0.03994838520884514, 0.027439186349511147, 0.025517335161566734, 0.02007225900888443, 0.007837814278900623, 0.007773424033075571, 0.006851928308606148, 0.005206955596804619]


In [124]:
# Build a toy dataset of 1000 records from the single example above; only `qid`
# differs between records.
k = 10
# The example never changes inside the loop, so rank the distributions and check
# their length once instead of on every iteration.
top_k_indices, top_k_probs = get_top_k(probs, k)
assert len(question.split()) == len(top_k_indices) == len(top_k_probs)

output = []
for i in range(1000):
    # All records reference the same top-k lists; that is fine for serialisation,
    # but copy them before mutating a single record.
    paragraph = {
                'qid': i,
                'context': context,
                'question': question,
                'top_k_indices': top_k_indices,
                'top_k_probs': top_k_probs,
                'answers': ans,
                'answer_start': ans_start
                }
    output.append(paragraph)

In [None]:
output[0]

In [108]:
import json
# Write the toy records to disk as a single JSON array.
with open('soft_target_test.json', 'w') as f:
    json.dump(output, f)

# Dataset Generation

In [202]:
import json
from tqdm.notebook import tqdm

# Load the source dataset; the cells below show it is a list of articles, each with
# a 'title' and a list of 'paragraphs'.
with open('data/UQA_finalQG_50k.json') as f:
    UQA_finalQG_50k = json.load(f)

In [203]:
len(UQA_finalQG_50k)

50000

In [204]:
UQA_finalQG_50k[0]

{'title': 'a5a90842b1c35606cdb2553f8ea14915864509ce',
 'paragraphs': [{'context': 'Their first concert as a group was made at the closing party of Melbourne’s Spanish Club on 17 June 2007. Although barely announced, word that an alleged member of TISM was unveiling a new project led to a large, expectant crowd assembling. From there, word of mouth spread, leading to heavy traffic on the band\'s nascent MySpace page, the creation of a fan website entitled The Root! Compendium, and growing demand for an album.The Root! Compendium - Archive - 17 June 2007. "After the gig I managed to catch up with DC Root who claimed that a CD\'s worth of material had been completed and was "ready to go,"" - Review by Adam.',
   'qas': [{'answers': [{'answer_start': 322, 'text': 'MySpace'}],
     'question': 'What',
     'NER_tag': 'ORG',
     'id': 0}]}]}

In [205]:
# dataset refinement (whitespace tokenize)
from transformers.tokenization_bert import whitespace_tokenize

# Normalize the whitespace of every context and answer (in place), then re-anchor
# answer_start inside the normalized context.
for article in tqdm(UQA_finalQG_50k):
    paragraph = article['paragraphs'][0]
    first_qa = paragraph['qas'][0]
    answer = first_qa['answers'][0]

    raw_context = paragraph['context']
    raw_start = answer.get('answer_start')

    paragraph['context'] = " ".join(whitespace_tokenize(raw_context))
    answer['text'] = " ".join(whitespace_tokenize(answer['text']))

    # The normalized text preceding the annotated answer is a prefix of the normalized
    # context, so searching from its length finds the annotated occurrence. A plain
    # find() from 0 would jump to an earlier, unrelated occurrence of the same text.
    if isinstance(raw_start, int) and raw_start >= 0:
        prefix_len = len(" ".join(whitespace_tokenize(raw_context[:raw_start])))
    else:
        prefix_len = 0
    new_start = paragraph['context'].find(answer['text'], prefix_len)
    if new_start == -1:
        # Annotation did not match the text: fall back to the first occurrence
        # (still -1 if the answer is absent; the data-check cell catches that).
        new_start = paragraph['context'].find(answer['text'])
    answer['answer_start'] = new_start

    # only keep the first question/answer entry of the paragraph
    paragraph['qas'] = [first_qa]

HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))




In [206]:
# Count the question/answer entries over all articles and paragraphs.
cnt = 0
for idx, article in enumerate(tqdm(UQA_finalQG_50k)):
    for paragraph in article['paragraphs']:
        for question in paragraph['qas']:
            cnt += 1
# Bare expression so the notebook displays the total.
cnt

HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))




50000

In [207]:
# data check: every stored answer_start must point at the answer text in its context
for idx, article in enumerate(tqdm(UQA_finalQG_50k)):
    for paragraph in article['paragraphs']:
        for question in paragraph['qas']:
            context = paragraph['context']
            ans = question['answers'][0]['text']
            ans_start = question['answers'][0]['answer_start']
#             print(idx, context[ans_start:ans_start+len(ans)], ans)
            # The slice at the stored offset must equal the answer exactly.
            assert context[ans_start:ans_start+len(ans)] == ans
            
            start_position = ans_start
            end_position = ans_start + len(ans)
#             actual_text = " ".join(context[start_position : (end_position + 1)])
            # NOTE: this slice extends one character past the end of the answer.
            actual_text = context[start_position : (end_position + 1)]
            cleaned_answer_text = " ".join(whitespace_tokenize(ans))

            # The whitespace-normalized answer must occur inside that slice; on failure
            # the article index and both strings are reported.
            assert actual_text.find(cleaned_answer_text) != -1, (idx, actual_text, cleaned_answer_text)


HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))




In [194]:
# Save the refined dataset next to the original file.
with open('data/UQA_finalQG_50k_refined.json', 'w') as f:
    json.dump(UQA_finalQG_50k, f)

In [208]:
# Reload the refined dataset from disk, replacing the in-memory copy.
with open('data/UQA_finalQG_50k_refined.json') as f:
    UQA_finalQG_50k = json.load(f)

In [209]:
def get_top_k(probs, k):
    '''
    For every per-token score sequence in `probs`, return the positions of the k
    highest scores (highest first) and the k highest scores themselves.
    Same definition as in the testing section of this notebook.
    '''
    top_k_indices = []
    top_k_probs = []
    for token_probs in probs:
        # Positions ranked by descending score.
        ranking = np.argsort(-np.array(token_probs))
        top_k_indices.append(ranking[:k].tolist())
        # Scores ranked independently, largest first.
        top_k_probs.append(sorted(token_probs, reverse=True)[:k])
    return top_k_indices, top_k_probs

In [None]:
# Generate one question per question/answer entry of the first `num_data` articles
# and store it together with the top-k candidates of every generated token.
k = 10
num_data = 10000
output = []
for article in tqdm(UQA_finalQG_50k[:num_data]):
    # Dedicated loop names: the previous version rebound the loop variables
    # `paragraph` and `question` inside their own loops.
    for source_paragraph in article['paragraphs']:
        for qa in source_paragraph['qas']:
            qid = qa['id']
            context = source_paragraph['context']
            ans = qa['answers'][0]['text']
            ans_start = qa['answers'][0]['answer_start']
            # The dataset stores the wh-word (e.g. 'What') in the 'question' field.
            wh_word = qa['question']
            # The stored offset must point at the answer text.
            assert context[ans_start:ans_start+len(ans)] == ans

            question, history, indices, probs = QG.generate_question(context, ans, ans_start, wh_word)

            top_k_indices, top_k_probs = get_top_k(probs, k)
            # One top-k entry per generated question token.
            assert len(question.split()) == len(top_k_indices) == len(top_k_probs)
            output.append({
                'id': qid,
                'context': context,
                'question': question,
                'top_k_indices': top_k_indices,
                'top_k_probs': top_k_probs,
                'answers': ans,
                'answer_start': ans_start
            })



HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

In [None]:
# Write the generated training records to disk as a single JSON array.
with open('data/finalQG_train_10k.json', 'w') as f:
    json.dump(output, f)

In [None]:
# question: what is the traffic leading on the root ! compendium page ?
# question: where is the new economic center of the united states has become a major exporter ?
# question: what else is the purpose of the drug ?
# question: where is flowing west through the commune of terron ?
# question: what is flowing in the commune of terron ?
# question: where were the first theories of logic in china and burma ?
# question: when did environmental statement come from the environment of bp , bulwer island refinery ?
# question: what is impersonating a girl in the movie ?
# question: what was partly because of the influence of kaura mall ?
# question: who was the presence of henderson and read the draft ?
# question: what are some verbs that have a vowel change ?

In [None]:
# question: what is the traffic leading on the root ! compendium page ?
# question: where is the new economic center of the united states ?
# question: what else is the purpose of the drug that is being investigated ?
# question: where is flowing west through the commune of terron ?
# question: what is flowing in the commune of terron flows west through ?
# question: where were the first theories of logic in china and burma ?
# question: when did environmental statement come from the environment of bp , bulwer island refinery was released ?
# question: what is impersonating a girl in the movie ?
# question: what was partly because of the influence of kaura mall ?
# question: who was the presence of henderson and read the draft ?
# question: what are some verbs that have a vowel change in the form of ?

In [None]:
# question: what is the traffic leading on the root ! compendium page ?
# question: where is the new economic center of the united states ?
# question: what else is the purpose of the drug that is being investigated ?
# question: where is flowing west through the commune of terron ?
# question: what is flowing in the commune of terron flows west through ?
# question: where were the first theories of logic in china and burma ?
# question: when did environmental statement come from the environment of bp , bulwer island refinery was released ?
# question: what is impersonating a girl in the movie ?
# question: what was partly because of the influence of kaura mall ?
# question: who was the presence of henderson and read the draft ?
# question: what are some verbs that have a vowel change in the form of ?

In [None]:
# question: what is the traffic leading on the root ! compendium page ?
# question: where is the new economic center of china ' s economy ?
# question: what else is the purpose of the drug that is being investigated ?
# question: where is flowing west through the commune of terron ?
# question: what is flowing in the commune of terron flows west through ?
# question: where were the first theories of logic in china and india ?
# question: when did environmental statement come from the environment of bp , bulwer island refinery was released ?
# question: what is impersonating a girl ' s best friend ?
# question: what was partly because of the sikhs ' influence on ?
# question: who was the presence of henderson ' s draft and ?
# question: what are some verbs that have a vowel change in the form of ?

In [None]:
# question: what is the traffic leading on the root ! compendium page ?
# question: where is the new economic center of china ' s economy ?
# question: what else is the purpose of the drug that is being investigated ?
# question: where is flowing west through the commune of terron ?
# question: what is flowing in the commune of terron , ?
# question: where were the first theories of logic in china , ?
# question: when did environmental statement , bulwer island refinery take place ?
# question: what is impersonating a girl ' s best friend ?
# question: what was partly because of the sikhs , they were not a ?
# question: who was the presence of henderson , he read the draft ?
# question: what are some verbs that have a vowel change in the form of ?

In [38]:
# Top-10 candidate indices per token for the most recently generated question
# (`probs` is whatever the last generate_question call left behind).
top_k_indices, _ = get_top_k(probs, 10)
# top_k_indices[3]

In [64]:
QG.tokenizer.convert_ids_to_tokens(top_k_indices[3])

['ru', '?', 'river', 'first', 'most', 'name', 'commune', 'who', 'other', 'one']

In [None]:
# question: what is the traffic leading what is the root ! com ##pen ##tagram ?

# question: what is the traffic leading where did you get the idea that you are leading ?

In [15]:
UQA_finalQG_50k[0]

{'title': 'a5a90842b1c35606cdb2553f8ea14915864509ce',
 'paragraphs': [{'context': 'Their first concert as a group was made at the closing party of Melbourne’s Spanish Club on 17 June 2007. Although barely announced, word that an alleged member of TISM was unveiling a new project led to a large, expectant crowd assembling. From there, word of mouth spread, leading to heavy traffic on the band\'s nascent MySpace page, the creation of a fan website entitled The Root! Compendium, and growing demand for an album.The Root! Compendium - Archive - 17 June 2007. "After the gig I managed to catch up with DC Root who claimed that a CD\'s worth of material had been completed and was "ready to go,"" - Review by Adam.',
   'qas': [{'answers': [{'answer_start': 322, 'text': 'MySpace'}],
     'question': 'What',
     'NER_tag': 'ORG',
     'id': 0}]}]}