In [1]:
import numpy as np
class Freq_Regularization():
    '''
    Chooses which QG model supplies the next question token.

    Token classes: 0 = UQA, 1 = LM. The probability of picking UQA equals the
    fraction of LM tokens generated so far (and vice versa), so the more one
    source has been used, the more likely the other one is picked next.
    '''

    def predict_class(self, list_token_classes):
        '''
        Sample the class of the next token.

        Input:
            - list_token_classes: classes (0 = UQA, 1 = LM) of the tokens
              generated so far
        Output:
            - 0 (UQA) or 1 (LM); uniformly random when the history is empty
        '''
        n_tokens = len(list_token_classes)
        if n_tokens:
            lm_share = sum(list_token_classes) / n_tokens
            weights = [lm_share, 1 - lm_share]
        else:
            weights = [0.5, 0.5]
        return np.random.choice(2, 1, p=weights)[0]


In [45]:
from BERTQG.token_generation import load_model, generate_token
#from Regularization_module import Regularizer_Discriminator
from itertools import groupby
import torch
import nltk
from nltk.corpus import stopwords
# you might need nltk.download('stopwords')

class QuestionGeneration():
    '''
    Token-by-token question generator that mixes two BERT QG models.

    At every step a regularizer picks which model supplies the next token
    (0 = UQA model, 1 = LM model). The chosen model's output distribution is
    filtered by __postprocessing, and its highest-scoring token is appended to
    the question until [SEP] is produced.
    '''

    # NLTK English stopwords, loaded on first use by __is_stopword. The
    # class-level None default means the corpus is only read when needed.
    _stopword_set = None

    def __init__(self, bert_model, lm_gq, uqa_qg, regularization=None):
        '''
        Input:
            - bert_model: BERT base model path, passed to load_model
            - lm_gq: checkpoint path of the LM QG model (parameter name kept
              unchanged for callers that pass it by keyword)
            - uqa_qg: checkpoint path of the UQA QG model. Its tokenizer and
              device are used for both models.
            - regularization: object with a predict_class(list_token_classes)
              method returning 0 (UQA) or 1 (LM); defaults to Freq_Regularization()
        '''
        self.lm_qg, _, _ = load_model(bert_model, lm_gq)
        self.uqa_qg, self.tokenizer, self.device = load_model(bert_model, uqa_qg)
        if regularization is None:
            regularization = Freq_Regularization()
        self.regularization = regularization

        # ids of every vocabulary entry containing "##" (##n, ##na, ...)
        list_bert_token_vocab = self.tokenizer.convert_ids_to_tokens(
            list(range(len(self.tokenizer.vocab))))
        self.list_subtoken_ids = [token_id for token_id, token in enumerate(list_bert_token_vocab)
                                  if "##" in token]

        self.list_wh = ['what', 'which', 'where', 'when', 'who', 'why', 'how', 'whom', 'whose']
        self.list_wh_ids = self.tokenizer.convert_tokens_to_ids(self.list_wh)

        self.list_pronouns = ['you', 'it', 'they', 'he', 'she', "i", 'we']
        self.list_pronouns_ids = self.tokenizer.convert_tokens_to_ids(self.list_pronouns)

    def generate_question(self, context: str, ans: str, ans_start: int,
                          list_question_tokens: list = None, max_tokens: int = 64,
                          verbose: bool = True) -> str:
        '''
        Generate a question for `ans` in `context`, one token at a time.

        Input:
            - context: passage the question is about
            - ans: answer text
            - ans_start: character offset of `ans` in `context` (e.g. context.find(ans))
            - list_question_tokens: question tokens (strings) to start from,
              e.g. ['where']. The list is not modified. __postprocessing masks
              wh-words at every step, so the wh-word must be supplied here.
            - max_tokens: maximum number of tokens to generate if [SEP] is never
              produced; None means no limit
            - verbose: print the top-k candidates of both models at every step
        Output:
            - the generated question (seed tokens included), with consecutive
              repeated tokens collapsed and the trailing [SEP] removed
        '''
        # Copy so neither the caller's list nor a shared default is mutated.
        question_tokens = list(list_question_tokens) if list_question_tokens else []
        # class of each token generated in this call: 0 = UQA, 1 = LM
        list_token_classes = []
        qi = None

        # generation finishes when [SEP] is produced (or the token budget runs out)
        while not self.__finished_generation(qi):
            if max_tokens is not None and len(list_token_classes) >= max_tokens:
                break
            question_text = " ".join(question_tokens)

            ## Generate tokens ##
            _, _, lm_qi_probs = generate_token(self.lm_qg, self.tokenizer, self.device,
                                               context, question_text, ans, ans_start)
            _, _, uqa_qi_probs = generate_token(self.uqa_qg, self.tokenizer, self.device,
                                                context, question_text, ans, ans_start)
            if verbose:
                print("lm tokens", self.__get_topk_tokens(lm_qi_probs))
                print("uqa_qi_probs", self.__get_topk_tokens(uqa_qi_probs))

            ## Regularization ##
            qi_class = self.regularization.predict_class(list_token_classes)

            ## Postprocessing ##
            prev_token = question_tokens[-1] if question_tokens else None
            num_generated_tokens = len(question_tokens)
            if qi_class == 1:
                chosen_probs = self.__postprocessing(lm_qi_probs, num_generated_tokens, prev_token)
                label = "updated lm tokens"
            else:
                chosen_probs = self.__postprocessing(uqa_qi_probs, num_generated_tokens, prev_token)
                label = "updated uqa_qi_probs"

            if verbose:
                print(label, self.__get_topk_tokens(chosen_probs))
                print("qi_class", qi_class)
                print("################")
            qi = self.__get_topk_tokens(chosen_probs, 1)[0]
            question_tokens.append(qi)
            list_token_classes.append(qi_class)

        question_tokens = self.__remove_consecutive_repeated_tokens(question_tokens)
        # Only strip [SEP]; if the token budget ran out there is none to remove.
        if question_tokens and question_tokens[-1] == '[SEP]':
            question_tokens = question_tokens[:-1]
        return " ".join(question_tokens)

    def __finished_generation(self, question_token):
        '''True once the [SEP] token has been generated.'''
        return question_token == '[SEP]'

    def __prob_distrib2tokens(self, probs):
        '''Convert a sequence of token ids to tokens (thin tokenizer wrapper).'''
        return self.tokenizer.convert_ids_to_tokens(probs)

    def __postprocessing(self, tensor_probs: torch.Tensor, num_generated_tokens: int, prev_token: str):
        '''
        Mask tokens that should not be generated next (sets their score to -inf).

        1) Subwords containing "##" (e.g. ##n). Usually helpful because the QG
           models give ##n a very high probability, although occasionally a
           subword is legitimately needed (e.g. "Fuji ##mori").
        2) Wh-words. Masked at every step, so the wh-word has to be supplied by
           the caller as a seed token.
        3) Pronouns: SQuAD-like (factual) questions do not need "you" or "I".
        4) A "?" that would end the question too early: right after two tokens
           ("what is ?") or after three tokens when the last one is a stopword
           ("what is the ?").

        Input:
            - tensor_probs: scores over the vocabulary, indexed by token id;
              modified in place
            - num_generated_tokens: number of question tokens so far (seed included)
            - prev_token: last question token, or None if there is none
        Output:
            - tensor_probs with the masked entries set to -inf
        '''
        tensor_probs = self.__remove_subwords(tensor_probs)
        tensor_probs = self.__remove_wh_words(tensor_probs)
        tensor_probs = self.__remove_pronouns(tensor_probs)

        # 4) avoid questions that end too early (need at least 3 real tokens)
        top_id = int(np.argsort(tensor_probs.tolist())[-1])
        predicted_token = self.tokenizer.convert_ids_to_tokens([top_id])[0]
        if predicted_token == "?" and (
                num_generated_tokens == 2
                or (num_generated_tokens == 3 and self.__is_stopword(prev_token))):
            tensor_probs = self.__ignore_probs(tensor_probs, [top_id])
        return tensor_probs

    def __is_stopword(self, token):
        '''True if `token` is an NLTK English stopword (corpus loaded once, on first call).'''
        if self._stopword_set is None:
            self._stopword_set = set(stopwords.words('english'))
        return token in self._stopword_set

    def __remove_subwords(self, tensor_probs):
        '''Mask every vocabulary entry containing "##".'''
        return self.__ignore_probs(tensor_probs, self.list_subtoken_ids)

    def __remove_wh_words(self, tensor_probs):
        '''Mask the wh-words listed in self.list_wh.'''
        return self.__ignore_probs(tensor_probs, self.list_wh_ids)

    def __remove_pronouns(self, tensor_probs):
        '''Mask the pronouns listed in self.list_pronouns.'''
        return self.__ignore_probs(tensor_probs, self.list_pronouns_ids)

    def __get_topk_tokens(self, probs, topk=10):
        '''
        Input:
            - probs: scores over the vocabulary (output distribution)
            - topk: number of tokens to return
        Output:
            - list of tokens, highest score first (top 1 to k)
        '''
        top_ids = np.argsort(probs.tolist())[-topk:][::-1]
        return self.tokenizer.convert_ids_to_tokens([int(i) for i in top_ids])

    def __ignore_probs(self, probs, idx_list):
        '''
        Input:
            - probs: scores over the vocabulary (output distribution);
              modified in place
            - idx_list: token ids to ignore (forced to -inf)
        Output:
            - probs, with the listed entries set to -inf
        '''
        idx_list = list(idx_list)
        if idx_list:
            # one advanced-indexing assignment instead of a Python loop over ids
            probs[idx_list] = -float('inf')
        return probs

    def __remove_consecutive_repeated_tokens(self, list_tokens):
        '''
        Collapse runs of identical consecutive tokens into a single token.
        The generated question is sometimes "when when did ...", so one
        "when" is removed, e.g. [1,1,2,3,3,1] -> [1,2,3,1].
        '''
        return [token for token, _ in groupby(list_tokens)]
    

SyntaxError: invalid syntax (<ipython-input-45-153cba79b36d>, line 116)

In [46]:
# Model locations, relative to the notebook's working directory: the shared
# BERT base model plus the two QG checkpoints (lm_10k_QG, uqa_10k_QG).
MODELS_DIR = 'BERTQG/models'
bert_model = MODELS_DIR + '/bert-base-uncased'
lm_qg = MODELS_DIR + '/lm_10k_QG/pytorch_model.bin'
uqa_qg = MODELS_DIR + '/uqa_10k_QG/pytorch_model.bin'
# regul_model = "./discri_model.bin"  # regularizer model path (currently unused)
QG = QuestionGeneration(bert_model, lm_qg, uqa_qg)

05/12/2020 13:43:24 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/12/2020 13:43:24 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/12/2020 13:43:24 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/lm_10k_QG/pytorch_model.bin


05/12/2020 13:43:26 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/12/2020 13:43:27 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/12/2020 13:43:27 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/uqa_10k_QG/pytorch_model.bin


# Testing: generate a question for a sample passage

In [42]:
# Sample passage used to test question generation.
context = (
    "The black bittern (Ixobrychus flavicollis) is a bittern of Old World origin, "
    "breeding in tropical Asia from Bangladesh, Pakistan, India, and Sri Lanka "
    "east to China, Indonesia, and Australia. "
    "It is mainly resident, but some northern birds migrate short distances."
)

In [43]:
ans = "Pakistan"
ans_start = context.find(ans)
# str.find returns -1 when the answer is absent; fail loudly instead of passing a bogus offset.
assert ans_start != -1, f"answer {ans!r} not found in context"

In [44]:
# Freq_Regularization picks each token's source model with np.random;
# seed it so the generated question is reproducible.
np.random.seed(0)
QG.generate_question(context, ans, ans_start, list_question_tokens=['where'])

lm tokens ['is', 'did', 'are', 'can', 'do', 'does', '?', 'where', 'it', 'will']
uqa_qi_probs ['is', 'are', 'can', 'breeding', 'in', 'else', 'do', 'was', 'to', 'and']
updated lm tokens ['is', 'did', 'are', 'can', 'do', 'does', '?', 'will', 'was', 'the']
qi_class 1
################
lm tokens ['it', 'the', '?', 'its', 'found', 'this', 'that', 'located', 'from', 'a']
uqa_qi_probs ['the', 'breeding', 'in', 'your', 'this', 'an', 'a', 'bangladesh', 'tropical', 'that']
updated uqa_qi_probs ['the', 'breeding', 'in', 'your', 'this', 'an', 'a', 'bangladesh', 'tropical', 'that']
qi_class 0
################
lm tokens ['bitter', '?', 'black', 'name', 'world', 'red', 'common', 'largest', 'most', 'burma']
uqa_qi_probs ['black', 'tropical', 'breeding', 'blue', 'white', 'african', 'red', 'negro', 'southern', 'south']
updated lm tokens ['bitter', '?', 'black', 'name', 'world', 'red', 'common', 'largest', 'most', 'burma']
qi_class 1
################
lm tokens ['##n', '##ns', '##nn', '##na', '##ne', '##m',

'where is the bitter of tropical asia ?'