In [1]:
import numpy as np
class Freq_Regularization():
    '''
    Picks which model supplies the next question token: UQA (class 0) or LM (class 1).

    The choice is biased against whichever model has supplied more of the tokens so far.
    This pushes the generated question toward an even mix of the two models.
    '''
    def __init__(self, rng=None):
        '''
        Input:
            - rng: optional random source with a numpy-style `choice` method
              (e.g. `np.random.default_rng(seed)`), for reproducible sampling.
              When None, the global `np.random` state is used.
        '''
        self.rng = rng

    def predict_class(self, list_token_classes):
        '''
        UQA class = 0
        LM class = 1

        Input:
            - list_token_classes: classes (0/1) of the question tokens generated so far
        Output:
            - 0 (UQA) or 1 (LM), the class to use for the next token.
              With no history, each class has probability 0.5.
        '''
        # Resolve lazily so the instance itself never holds a (non-picklable) module.
        rng = np.random if self.rng is None else self.rng

        if len(list_token_classes) == 0:
            return int(rng.choice(2, 1, p=[0.5, 0.5])[0])

        # The sum of 0/1 classes counts LM tokens, so this is the LM frequency so far.
        lm_frequency = sum(list_token_classes) / len(list_token_classes)
        # Use the LM frequency as the probability of picking UQA (and vice versa),
        # so the less-used model is favoured.
        prob_uqa = lm_frequency
        prob_lm = 1 - prob_uqa
        return int(rng.choice(2, 1, p=[prob_uqa, prob_lm])[0])


In [2]:
from BERTQG.token_generation import load_model, generate_token
#from Regularization_module import Regularizer_Discriminator
from itertools import groupby

# Wh-question words. Once one appears in a question, the generator bans the rest.
list_wh = 'what which where when who why how whom whose'.split()

class QuestionGeneration():
    '''
    Token-by-token question generator that mixes two BERT question-generation
    checkpoints: an "LM" model and a "UQA" model. Both share the tokenizer that
    is loaded with the UQA checkpoint.

    At each step both models score the next token. A regularizer then picks which
    model's filtered top-1 token is appended. Generation stops when '[SEP]' is chosen.
    '''
    def __init__(self, bert_model, lm_gq, uqa_qg, regularization=None):
        '''
        Input:
            - bert_model: base BERT model path, passed to `load_model` for both checkpoints
            - lm_gq: path of the LM question-generation checkpoint
              (name kept as-is so keyword callers keep working)
            - uqa_qg: path of the UQA question-generation checkpoint
            - regularization: object with `predict_class(list_token_classes) -> 0/1`
              (0 = UQA, 1 = LM). Defaults to `Freq_Regularization()`.
        Raises:
            - TypeError if `regularization` has no `predict_class` method
        '''
        # Validate before the (expensive) model loading.
        if regularization is None:
            regularization = Freq_Regularization()
        elif not hasattr(regularization, 'predict_class'):
            raise TypeError("regularization must provide predict_class(list_token_classes)")
        self.regularization = regularization

        # Use the `lm_gq` argument itself, not a same-named notebook global.
        self.lm_qg, _, _ = load_model(bert_model, lm_gq)
        self.uqa_qg, self.tokenizer, self.device = load_model(bert_model, uqa_qg)

    def generate_question(self, context: str, ans: str, ans_start: int,
                          list_question_tokens: list = None, verbose: bool = True) -> str:
        '''
        Generate a full question, one token at a time, until '[SEP]' is produced.

        Input:
            - context: passage the question is about
            - ans: answer text
            - ans_start: character offset of `ans` in `context`
            - list_question_tokens: optional seed tokens the question starts with,
              e.g. ['where']. The list is copied and never modified in place.
            - verbose: print per-step diagnostics (top-k tokens, chosen class, restarts)
        Output:
            - The generated question as a space-joined string, without the trailing '[SEP]'
        '''
        # Copy so neither the caller's list nor a shared default gets mutated.
        list_question_tokens = list(list_question_tokens) if list_question_tokens else []
        # Class of each generated token (0 = UQA, 1 = LM); seed tokens are not counted.
        list_token_classes = []
        list_banned_verbs = []
        list_question_failures = []
        qi = None

        # NOTE(review): there is no length cap. If neither model ever yields '[SEP]',
        # this loop does not terminate.
        while not self.__finished_generation(qi):
            question_text = " ".join(list_question_tokens)

            # Next-token distributions from both models for the current prefix.
            _, _, lm_qi_probs = generate_token(self.lm_qg, self.tokenizer, self.device, context, question_text, ans, ans_start)
            _, _, uqa_qi_probs = generate_token(self.uqa_qg, self.tokenizer, self.device, context, question_text, ans, ans_start)

            # Compute top-k once; it is used both for logging and for the corruption check.
            list_lm_tokens = self.__get_topk_tokens(lm_qi_probs)
            list_uqa_tokens = self.__get_topk_tokens(uqa_qi_probs)
            if verbose:
                print("lm tokens", list_lm_tokens)
                print("uqa_qi_probs", list_uqa_tokens)

            # The restart below needs a token at index 1 to ban. With a shorter history
            # there is nothing to ban, so the corruption check is skipped.
            if len(list_question_tokens) >= 2 and self.__corrupted_question(
                    list_question_tokens[-1], list_lm_tokens, list_uqa_tokens):
                if verbose:
                    print("Restarting question!!!!")
                # Ban the second token (presumably the verb after the wh-word) and
                # restart from the first token only.
                list_banned_verbs.append(list_question_tokens[1])
                list_question_tokens = list_question_tokens[:1]
                list_question_failures.append(question_text)
                list_token_classes = []
                continue

            # Which model supplies this token: 1 = LM, anything else = UQA.
            qi_class = self.regularization.predict_class(list_token_classes)
            chosen_probs = lm_qi_probs if qi_class == 1 else uqa_qi_probs
            qi, _ = self.__clean_token_distrib(chosen_probs, list_question_tokens, list_banned_verbs)

            if verbose:
                print("qi_class", qi_class)
                print("################")
            list_question_tokens.append(qi)
            list_token_classes.append(qi_class)

        list_question_tokens = self.__remove_consecutive_repeated_tokens(list_question_tokens)
        if verbose:
            print("Attempted Questions:", list_question_failures)
        # The loop exits only right after '[SEP]' is appended, so drop that last token.
        return " ".join(list_question_tokens[:-1])

    def __finished_generation(self, question_token):
        '''True when the given token is the end-of-question marker '[SEP]'.'''
        return question_token == '[SEP]'

    def __prob_distrib2tokens(self, probs):
        '''
        Convert token ids to token strings via the tokenizer.
        Despite its name, the argument is passed straight to `convert_ids_to_tokens`.
        Not called anywhere within this class.
        '''
        return self.tokenizer.convert_ids_to_tokens(probs)

    def __get_topk_tokens(self, probs, topk=10):
        '''
        Input:
            - probs: scores over the vocabulary (output distribution); must support
              `.tolist()` (e.g. a numpy array or a tensor — assumed 1-D)
            - topk: number of tokens to return
        Output:
            - list of the `topk` highest-scoring tokens, best first
        '''
        # argsort is ascending: take the last k ids and reverse them to get best-first.
        return self.tokenizer.convert_ids_to_tokens(reversed(np.argsort(probs.tolist())[-topk:]))

    def __ignore_probs(self, probs, idx_list):
        '''
        Input:
            - probs: scores over the vocabulary (output distribution)
            - idx_list: vocabulary ids to suppress
        Output:
            - `probs` itself (modified in place), with the given ids set to -inf
              so they can never be selected
        '''
        for idx in idx_list:
            probs[idx] = -float('inf')
        return probs

    def __clean_token_distrib(self, probs, list_tokens: list, list_banned_verbs: list):
        '''
        Suppress disallowed tokens, then pick the best remaining one.

        Disallowed tokens are pronouns, banned verbs, and (once a wh-word is in the
        question) every wh-word.

        Input:
            - probs: scores over the vocabulary; modified in place
            - list_tokens: question tokens generated so far
            - list_banned_verbs: tokens banned by earlier question restarts
        Output:
            - (top-1 token, cleaned probs)
        '''
        wh_generated = any(token.lower() in list_wh for token in list_tokens)
        list_to_remove = ['you', 'it', 'they', 'he', 'she', "i", 'we']
        list_to_remove.extend(list_banned_verbs)
        if wh_generated:
            # A wh-word is already in the question; do not generate another one.
            list_to_remove.extend(list_wh)
        ignore_idx_list = self.tokenizer.convert_tokens_to_ids(list_to_remove)
        probs = self.__ignore_probs(probs, ignore_idx_list)
        return self.__get_topk_tokens(probs)[0], probs

    def __corrupted_question(self, prev_token: str, list_lm_tokens: list, list_uqa_tokens: list):
        '''
        Detect a question that has drifted into sub-word garbage. This happens when both
        models mostly propose '##' continuation pieces right after a full vocabulary
        token, for example:
            lm tokens ['##n', '##ns', '##nn', '##na', '##ne', '##m', '##nne', '##nb', '##nt', '##nia']
            uqa_qi_probs ['##n', '##ns', '##nh', '##ne', '##na', '##nst', '##s', '##no', '##ness', '##ng']
        In that case the caller should backtrack; maybe a different verb is needed.
        Returns False when there is no previous token.
        '''
        if prev_token is None:
            return False
        # Only the top 10 candidates of each model are considered.
        list_lm_tokens = list_lm_tokens[:10]
        list_uqa_tokens = list_uqa_tokens[:10]

        num_hash_lm = sum(1 for token in list_lm_tokens if "##" in token)
        num_hash_uqa = sum(1 for token in list_uqa_tokens if "##" in token)
        # NOTE(review): the thresholds are asymmetric (>= 7 vs > 7). They are kept as-is;
        # confirm whether that is intended.
        return num_hash_lm >= 7 and num_hash_uqa > 7 and prev_token in self.tokenizer.vocab

    def __remove_consecutive_repeated_tokens(self, list_tokens):
        '''
        Collapse runs of identical consecutive tokens into one token.
        Sometimes the generated question is "when when did...", so one "when" must go.
        Example: [1,1,1,1,1,1,2,3,4,4,5,1,2] -> [1, 2, 3, 4, 5, 1, 2]
        '''
        return [key for key, _ in groupby(list_tokens)]
    

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
# The base BERT model plus the two fine-tuned question-generation checkpoints.
# QuestionGeneration loads both checkpoints on top of `bert_model`.
bert_model, lm_qg, uqa_qg = (
    'BERTQG/models/bert-base-uncased',
    'BERTQG/models/lm_10k_QG/pytorch_model.bin',
    'BERTQG/models/uqa_10k_QG/pytorch_model.bin',
)
QG = QuestionGeneration(bert_model, lm_qg, uqa_qg)

05/11/2020 06:21:04 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/11/2020 06:21:06 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/11/2020 06:21:06 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/lm_10k_QG/pytorch_model.bin


05/11/2020 06:21:08 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
05/11/2020 06:21:08 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
05/11/2020 06:21:08 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/uqa_10k_QG/pytorch_model.bin


# Testing: generate a question for a sample passage

In [4]:
context = "The black bittern (Ixobrychus flavicollis) is a bittern of Old World origin, breeding in tropical Asia from Bangladesh, Pakistan, India, and Sri Lanka east to China, Indonesia, and Australia. It is mainly resident, but some northern birds migrate short distances."

In [5]:
ans = "Pakistan"
# str.find returns -1 when the answer is missing. That would silently hand the model a
# bogus character offset, so fail loudly instead.
ans_start = context.find(ans)
assert ans_start != -1, f"answer {ans!r} not found in context"

In [6]:
# Seed the question with the wh-word "where"; the models generate the rest.
QG.generate_question(context, ans, ans_start, list_question_tokens=['where'])

lm tokens ['is', 'did', 'are', 'can', 'do', 'does', '?', 'where', 'it', 'will']
uqa_qi_probs ['is', 'are', 'can', 'breeding', 'in', 'else', 'do', 'was', 'to', 'and']
qi_class 0
################
lm tokens ['it', 'the', '?', 'its', 'found', 'this', 'that', 'located', 'from', 'a']
uqa_qi_probs ['the', 'breeding', 'in', 'your', 'this', 'an', 'a', 'bangladesh', 'tropical', 'that']
qi_class 1
################
lm tokens ['bitter', '?', 'black', 'name', 'world', 'red', 'common', 'largest', 'most', 'burma']
uqa_qi_probs ['black', 'tropical', 'breeding', 'blue', 'white', 'african', 'red', 'negro', 'southern', 'south']
qi_class 1
################
lm tokens ['##n', '##ns', '##nn', '##na', '##ne', '##m', '##nne', '##nb', '##nt', '##nia']
uqa_qi_probs ['##n', '##ns', '##nh', '##ne', '##na', '##nst', '##s', '##no', '##ness', '##ng']
Restarting question!!!!
lm tokens ['is', 'did', 'are', 'can', 'do', 'does', '?', 'where', 'it', 'will']
uqa_qi_probs ['is', 'are', 'can', 'breeding', 'in', 'else', 'do', 

'where did breeding occur in tropical asia ?'