In [1]:
import numpy as np
class Freq_Regularization():
    '''
    Frequency-based regularizer that decides which generator provides the next token.

    The fraction of LM tokens in the history is used as the probability of picking
    UQA next (and vice versa), so the choice leans towards the class that has been
    used less so far.
    '''
    def predict_class(self, list_token_classes):
        '''
        UQA class = 0
        LM class = 1

        Input:
            - list_token_classes: classes of the tokens generated so far. Only the
            values 0 and 1 are counted; anything else (e.g. the -1 placeholders that
            QuestionGeneration.generate_question assigns to seed tokens) is ignored
            so that it cannot produce an invalid probability vector.
        Output:
            - 0 or 1, sampled with np.random.choice. With no usable history both
            classes are equally likely.
        '''
        known_classes = [cls for cls in list_token_classes if cls in (0, 1)]
        if len(known_classes) == 0:
            return np.random.choice(2, 1, p=[0.5, 0.5])[0]

        # sum() counts the LM tokens (class 1), so this is the LM fraction so far;
        # it is deliberately used as the probability of choosing UQA next.
        prob_uqa = sum(known_classes)/len(known_classes)
        prob_lm = 1 - prob_uqa
        return np.random.choice(2, 1, p=[prob_uqa, prob_lm])[0]


In [125]:
from BERTQG.token_generation import load_model, generate_token
from Regularization_module import Regularizer_Discriminator
from itertools import groupby


class QuestionGeneration():
    '''
    Generates a question for a (context, answer) pair one token at a time.

    At every step both the LM and the UQA generator propose a token and the
    regularizer decides which of the two proposals is kept.
    '''
    SEP_TOKEN = '[SEP]'        # token that marks the end of a generated question
    MAX_QUESTION_TOKENS = 50   # generation stops once this many tokens are held

    def __init__(self, bert_model, lm_gq, uqa_qg, regularization=None):
        '''
        Input:
            - bert_model: first argument given to load_model for both generators
            - lm_gq: second argument given to load_model for the LM generator. The
            spelling of the parameter name is kept as-is so that existing keyword
            callers keep working.
            - uqa_qg: second argument given to load_model for the UQA generator
            - regularization: forwarded to Regularizer_Discriminator
        '''
        # Read the parameter itself; the previous body read a name `lm_qg` that is
        # not a parameter, so it only worked when a global of that name existed.
        self.lm_qg, _, _ = load_model(bert_model, lm_gq)
        self.uqa_qg, self.tokenizer, self.device = load_model(bert_model, uqa_qg)
        self.regularization = Regularizer_Discriminator(regularization)

    def generate_question(self, context: str, ans: str, ans_start: int, list_question_tokens: list = None) -> tuple:
        '''
        Input:
            - context
            - ans
            - ans_start
            - list_question_tokens: optional history of already generated question
            tokens, given for testing. The entries are joined with " " to build the
            partial question, so they must be strings. The list is copied and never
            modified.
        Output:
            - the generated question (tokens joined with " ", without [SEP])
            - list of token classes (0 = UQA, 1 = LM, -1 for tokens taken from
            list_question_tokens)
            - list of the token indices returned by generate_token (-1 for tokens
            taken from list_question_tokens)
            - list of the token probabilities returned by generate_token (-1 for
            tokens taken from list_question_tokens)
            All four have one entry per kept token.
        '''
        # Work on a private copy so neither the caller's list nor a shared default
        # is ever appended to.
        list_question_tokens = list(list_question_tokens) if list_question_tokens is not None else []

        # Seed tokens (testing only) have no known class / index / probabilities.
        len_initial_tokens = len(list_question_tokens)
        list_token_classes = [-1] * len_initial_tokens # 0 = UQA, 1 = LM
        list_qi_idx = [-1] * len_initial_tokens
        list_qi_probs = [-1] * len_initial_tokens

        # At least one token is always generated; generation ends when [SEP] is
        # produced or when the length cap is reached.
        while True:
            question_text = " ".join(list_question_tokens)

            # Generate the tokens and probs of the ith query token using the lm and uqa models
            lm_qi_token, lm_qi_idx, lm_qi_probs = generate_token(self.lm_qg, self.tokenizer, self.device, context, question_text, ans, ans_start)
            uqa_qi_token, uqa_qi_idx, uqa_qi_probs = generate_token(self.uqa_qg, self.tokenizer, self.device, context, question_text, ans, ans_start)

            # Get token class to use
            qi_class = self.regularization.predict_class(context, ans, ans_start, question_text, list_question_tokens)

            # Get the predicted token
            if qi_class == 1: # LM
                qi, qi_idx, qi_probs = lm_qi_token, lm_qi_idx, lm_qi_probs
            else: # UQA
                qi, qi_idx, qi_probs = uqa_qi_token, uqa_qi_idx, uqa_qi_probs

            list_question_tokens.append(qi)
            list_token_classes.append(qi_class)
            list_qi_idx.append(qi_idx)
            list_qi_probs.append(qi_probs)

            if self.__finished_generation(qi):
                break
            if len(list_question_tokens) >= self.MAX_QUESTION_TOKENS:
                break

        # indices to keep
        list_idx = self.__remove_consecutive_repeated_tokens(list_question_tokens)

        # Drop the terminating [SEP] only when it is really there. When the loop
        # stopped on the length cap the last kept token is a real question token
        # and must not be thrown away.
        if list_idx and self.__finished_generation(list_question_tokens[list_idx[-1]]):
            list_idx = list_idx[:-1]

        list_question_tokens = [list_question_tokens[idx] for idx in list_idx]
        list_token_classes = [list_token_classes[idx] for idx in list_idx]
        list_qi_idx = [list_qi_idx[idx] for idx in list_idx]
        list_qi_probs = [list_qi_probs[idx] for idx in list_idx]

        assert len(list_question_tokens) == len(list_token_classes) == len(list_qi_idx) == len(list_qi_probs)

        return " ".join(list_question_tokens), list_token_classes, list_qi_idx, list_qi_probs

    def __finished_generation(self, question_token):
        '''
        Returns True when question_token is the [SEP] token.
        '''
        return question_token == self.SEP_TOKEN

    def __remove_consecutive_repeated_tokens(self, list_tokens):
        '''
        Removes consecutive tokens.
        Sometimes the generated question is "when when did...",
        so we need to remove one when
        Output:
            - list of indices to keep: the index of the first token of every run
            of equal consecutive tokens
        Example:
            [1,1,1,1,1,1,2,3,4,4,5,1,2] -> [0, 6, 7, 8, 10, 11, 12]
        '''
        # groupby yields one group per run of equal tokens; the first pair of each
        # group carries the index of the token that is kept.
        return [next(group)[0] for _, group in groupby(enumerate(list_tokens), key=lambda pair: pair[1])]


In [126]:
# Paths handed to QuestionGeneration: the BERT model directory, the two
# question-generation checkpoints (LM and UQA) and the regularizer checkpoint.
bert_model = 'BERTQG/models/bert-base-uncased'
lm_qg = 'BERTQG/models/lm_10k_QG/pytorch_model.bin'
uqa_qg = 'BERTQG/models/uqa_10k_QG/pytorch_model.bin'
regul_model = 'models/discri_model_partial_perturb.bin' # model path here
# Build the generator; the fourth argument is forwarded to Regularizer_Discriminator.
QG = QuestionGeneration(bert_model, lm_qg, uqa_qg, regul_model)

05/11/2020 16:05:13 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/11/2020 16:05:13 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/11/2020 16:05:13 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/lm_10k_QG/pytorch_model.bin


05/11/2020 16:05:15 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/11/2020 16:05:15 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/11/2020 16:05:15 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/uqa_10k_QG/pytorch_model.bin


05/11/2020 16:05:18 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517
05/11/2020 16:05:18 - INFO - transformers.configuration_utils -   Model config BertConfig {
  "_num_labels": 2,
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
   

# Testing

In [127]:
# Example passage used to exercise the generator. The year range is written with
# a real en dash; the previous literal contained its mis-decoded form.
context = "The city has a proud history of theatre. Stephen Kemble of the famous Kemble family successfully managed the original Theatre Royal, Newcastle for fifteen years (1791–1806). He brought members of his famous acting family such as Sarah Siddons and John Kemble out of London to Newcastle. Stephen Kemble guided the theatre through many celebrated seasons. The original Theatre Royal in Newcastle was opened on 21 January 1788 and was located on Mosley Street. It was demolished to make way for Grey Street, where its replacement was built."
# Reference question for this passage:
# question_text = "when did the original theatre royal open ?"
ans = "1788"
# Character offset of the answer; derived from the passage so it stays correct
# whenever the passage text changes.
ans_start = context.find(ans)
# Fails if the answer does not occur in the passage (find() returned -1).
assert context[ans_start:ans_start+len(ans)] == ans

In [128]:
# Generate a question from scratch. The four results are the question text, the
# class of every kept token (0 = UQA, 1 = LM), and the per-token indices and
# probabilities that generate_token returned.
question, history, indices, probs = QG.generate_question(context, ans, ans_start)

In [129]:
# Generated question text (first value returned by generate_question).
question

'when did the original theatre royal open on 21 january ?'

In [130]:
# Which generator produced each kept token: 0 = UQA, 1 = LM.
history

[0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1]

In [131]:
# Per-token indices as returned by generate_token, one per kept token.
indices

[2043, 2106, 1996, 2434, 3004, 2548, 2330, 2006, 2538, 2254, 1029]

In [132]:
# Probability stored at index 102 for the last kept token.
# NOTE(review): 102 is presumably the [SEP] id of the bert-base-uncased vocabulary — confirm.
probs[-1][102]

0.00454629585146904

In [133]:
def get_top_k(probs, k):
    '''
    For every entry of probs, returns the k highest-probability candidates.
    Input:
        - probs: sequence with one probability sequence per token
        - k: number of candidates to keep per token
    Output:
        - list (one item per token) of the k indices with the highest probability,
        best first
        - list (one item per token) of the k highest probabilities, best first
    '''
    best_indices = []
    best_probs = []
    for token_probs in probs:
        # negate so that an ascending argsort ranks the largest probability first
        ranking = np.argsort(-np.array(token_probs))
        best_indices.append(ranking[:k].tolist())
        best_probs.append(sorted(token_probs, reverse=True)[:k])
    return best_indices, best_probs

In [134]:
# Keep the 10 best candidates per token and show them for the last kept token:
# first the candidate indices, then their probabilities (both best first).
top_k_indices, top_k_probs = get_top_k(probs, 10)
print(top_k_indices[-1])
print(top_k_probs[-1])

[1029, 1012, 13739, 1010, 2960, 15622, 13393, 1024, 2997, 1011]
[0.6433014273643494, 0.10247533023357391, 0.03994838520884514, 0.027439186349511147, 0.025517335161566734, 0.02007225900888443, 0.007837814278900623, 0.007773424033075571, 0.006851928308606148, 0.005206955596804619]


In [124]:
# Build 1000 records that differ only in their qid.
output = []
k = 10
# The top-k summary depends only on `probs` and `k`, so it is computed (and
# checked) once instead of re-sorting every distribution on each iteration.
top_k_indices, top_k_probs = get_top_k(probs, k)
# one top-k row per question token
assert len(question.split()) == len(top_k_indices) == len(top_k_probs)
for i in range(1000):
    paragraph = {
                'qid': i,
                'context': context,
                'question': question,
                # every record gets its own copy of the rows, so editing one
                # record later cannot change the others
                'top_k_indices': [list(row) for row in top_k_indices],
                'top_k_probs': [list(row) for row in top_k_probs],
                'answers': ans,
                'answer_start': ans_start
                }
    output.append(paragraph)

In [None]:
# Spot-check the first record.
output[0]

In [108]:
import json
# Write all records to disk as one JSON array.
with open('soft_target_test.json', 'w') as f:
    json.dump(output, f)